diff --git a/src/algorithm/dft.rs b/src/algorithm/dft.rs index 4f8443be..e0b700cf 100644 --- a/src/algorithm/dft.rs +++ b/src/algorithm/dft.rs @@ -38,6 +38,13 @@ impl Dft { } } + fn inplace_scratch_len(&self) -> usize { + self.len() + } + fn outofplace_scratch_len(&self) -> usize { + 0 + } + fn perform_fft_out_of_place( &self, signal: &[Complex], diff --git a/src/algorithm/mod.rs b/src/algorithm/mod.rs index 573c3622..6299f18f 100644 --- a/src/algorithm/mod.rs +++ b/src/algorithm/mod.rs @@ -5,6 +5,7 @@ mod mixed_radix; mod raders_algorithm; mod radix3; mod radix4; +mod radixn; /// Hardcoded size-specfic FFT algorithms pub mod butterflies; @@ -16,3 +17,4 @@ pub use self::mixed_radix::{MixedRadix, MixedRadixSmall}; pub use self::raders_algorithm::RadersAlgorithm; pub use self::radix3::Radix3; pub use self::radix4::Radix4; +pub use self::radixn::RadixN; diff --git a/src/algorithm/radix3.rs b/src/algorithm/radix3.rs index 013b660f..d392f8c1 100644 --- a/src/algorithm/radix3.rs +++ b/src/algorithm/radix3.rs @@ -1,9 +1,9 @@ use std::sync::Arc; use num_complex::Complex; -use num_traits::Zero; use crate::algorithm::butterflies::{Butterfly1, Butterfly27, Butterfly3, Butterfly9}; +use crate::algorithm::radixn::butterfly_3; use crate::array_utils::{self, bitreversed_transpose, compute_logarithm}; use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles, FftDirection}; @@ -32,6 +32,8 @@ pub struct Radix3 { len: usize, direction: FftDirection, + inplace_scratch_len: usize, + outofplace_scratch_len: usize, } impl Radix3 { @@ -68,10 +70,11 @@ impl Radix3 { // but mixed radix only does one step and then calls itself recusrively, and this algorithm does every layer all the way down // so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up const ROW_COUNT: usize = 3; - let mut cross_fft_len = base_len * ROW_COUNT; + let mut cross_fft_len = base_len; let mut twiddle_factors = Vec::with_capacity(len * 2); - while cross_fft_len <= len { - let num_columns = cross_fft_len / ROW_COUNT; + while cross_fft_len < len { + let num_columns = cross_fft_len; + cross_fft_len *= ROW_COUNT; for i in 0..num_columns { for k in 1..ROW_COUNT { @@ -79,9 +82,20 @@ impl Radix3 { twiddle_factors.push(twiddle); } } - cross_fft_len *= ROW_COUNT; } + let base_inplace_scratch = base_fft.get_inplace_scratch_len(); + let inplace_scratch_len = if base_inplace_scratch > cross_fft_len { + cross_fft_len + base_inplace_scratch + } else { + cross_fft_len + }; + let outofplace_scratch_len = if base_inplace_scratch > len { + base_inplace_scratch + } else { + 0 + }; + Self { twiddles: twiddle_factors.into_boxed_slice(), butterfly3: Butterfly3::new(direction), @@ -91,14 +105,24 @@ impl Radix3 { len, direction, + + inplace_scratch_len, + outofplace_scratch_len, } } + fn inplace_scratch_len(&self) -> usize { + self.inplace_scratch_len + } + fn outofplace_scratch_len(&self) -> usize { + self.outofplace_scratch_len + } + fn perform_fft_out_of_place( &self, - input: &[Complex], + input: &mut [Complex], output: &mut [Complex], - _scratch: &mut [Complex], + scratch: &mut [Complex], ) { // copy the data into the output vector if self.len() == self.base_len { @@ -108,63 +132,30 @@ impl Radix3 { } // Base-level FFTs - self.base_fft.process_with_scratch(output, &mut []); + let base_scratch = if scratch.len() > 0 { scratch } else { input }; + self.base_fft.process_with_scratch(output, base_scratch); // cross-FFTs const ROW_COUNT: usize = 3; - let mut cross_fft_len = self.base_len * ROW_COUNT; + let mut cross_fft_len = self.base_len; let mut layer_twiddles: &[Complex] = &self.twiddles; - while cross_fft_len <= input.len() { - let num_rows = input.len() / cross_fft_len; - let num_columns = cross_fft_len / ROW_COUNT; - - for i in 0..num_rows { - unsafe { - butterfly_3( - &mut output[i * cross_fft_len..], - layer_twiddles, - num_columns, - &self.butterfly3, - ) - } + while cross_fft_len < output.len() { + let num_columns = cross_fft_len; + cross_fft_len *= ROW_COUNT; + + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_3(data, layer_twiddles, num_columns, &self.butterfly3) } } // skip past all the twiddle factors used in this layer let twiddle_offset = num_columns * (ROW_COUNT - 1); layer_twiddles = &layer_twiddles[twiddle_offset..]; - - cross_fft_len *= ROW_COUNT; } } } boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len); -unsafe fn butterfly_3( - data: &mut [Complex], - twiddles: &[Complex], - num_ffts: usize, - butterfly3: &Butterfly3, -) { - let mut idx = 0usize; - let mut tw_idx = 0usize; - let mut scratch = [Zero::zero(); 3]; - for _ in 0..num_ffts { - scratch[0] = *data.get_unchecked(idx); - scratch[1] = *data.get_unchecked(idx + 1 * num_ffts) * twiddles[tw_idx]; - scratch[2] = *data.get_unchecked(idx + 2 * num_ffts) * twiddles[tw_idx + 1]; - - butterfly3.perform_fft_butterfly(&mut scratch); - - *data.get_unchecked_mut(idx) = scratch[0]; - *data.get_unchecked_mut(idx + 1 * num_ffts) = scratch[1]; - *data.get_unchecked_mut(idx + 2 * num_ffts) = scratch[2]; - - tw_idx += 2; - idx += 1; - } -} - #[cfg(test)] mod unit_tests { use super::*; diff --git a/src/algorithm/radix4.rs b/src/algorithm/radix4.rs index 40140b01..33a804e4 100644 --- a/src/algorithm/radix4.rs +++ b/src/algorithm/radix4.rs @@ -1,11 +1,11 @@ use std::sync::Arc; use num_complex::Complex; -use num_traits::Zero; use crate::algorithm::butterflies::{ Butterfly1, Butterfly16, Butterfly2, Butterfly32, Butterfly4, Butterfly8, }; +use crate::algorithm::radixn::butterfly_4; use crate::array_utils::{self, bitreversed_transpose}; use crate::common::{fft_error_inplace, fft_error_outofplace}; use crate::{common::FftNum, twiddles, FftDirection}; @@ -33,6 +33,8 @@ pub struct Radix4 { len: usize, direction: FftDirection, + inplace_scratch_len: usize, + outofplace_scratch_len: usize, } impl Radix4 { @@ -75,10 +77,11 @@ impl Radix4 { // but mixed radix only does one step and then calls itself recusrively, and this algorithm does every layer all the way down // so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up const ROW_COUNT: usize = 4; - let mut cross_fft_len = base_len * ROW_COUNT; + let mut cross_fft_len = base_len; let mut twiddle_factors = Vec::with_capacity(len * 2); - while cross_fft_len <= len { - let num_columns = cross_fft_len / ROW_COUNT; + while cross_fft_len < len { + let num_columns = cross_fft_len; + cross_fft_len *= ROW_COUNT; for i in 0..num_columns { for k in 1..ROW_COUNT { @@ -86,9 +89,20 @@ impl Radix4 { twiddle_factors.push(twiddle); } } - cross_fft_len *= ROW_COUNT; } + let base_inplace_scratch = base_fft.get_inplace_scratch_len(); + let inplace_scratch_len = if base_inplace_scratch > cross_fft_len { + cross_fft_len + base_inplace_scratch + } else { + cross_fft_len + }; + let outofplace_scratch_len = if base_inplace_scratch > len { + base_inplace_scratch + } else { + 0 + }; + Self { twiddles: twiddle_factors.into_boxed_slice(), @@ -97,14 +111,24 @@ impl Radix4 { len, direction, + + inplace_scratch_len, + outofplace_scratch_len, } } + fn inplace_scratch_len(&self) -> usize { + self.inplace_scratch_len + } + fn outofplace_scratch_len(&self) -> usize { + self.outofplace_scratch_len + } + fn perform_fft_out_of_place( &self, - input: &[Complex], + input: &mut [Complex], output: &mut [Complex], - _scratch: &mut [Complex], + scratch: &mut [Complex], ) { // copy the data into the output vector if self.len() == self.base_len { @@ -114,67 +138,32 @@ impl Radix4 { } // Base-level FFTs - self.base_fft.process_with_scratch(output, &mut []); + let base_scratch = if scratch.len() > 0 { scratch } else { input }; + self.base_fft.process_with_scratch(output, base_scratch); // cross-FFTs const ROW_COUNT: usize = 4; - let mut cross_fft_len = self.base_len * ROW_COUNT; + let mut cross_fft_len = self.base_len; let mut layer_twiddles: &[Complex] = &self.twiddles; - while cross_fft_len <= input.len() { - let num_rows = input.len() / cross_fft_len; - let num_columns = cross_fft_len / ROW_COUNT; - - for i in 0..num_rows { - unsafe { - butterfly_4( - &mut output[i * cross_fft_len..], - layer_twiddles, - num_columns, - self.direction, - ) - } + let butterfly4 = Butterfly4::new(self.direction); + + while cross_fft_len < output.len() { + let num_columns = cross_fft_len; + cross_fft_len *= ROW_COUNT; + + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_4(data, layer_twiddles, num_columns, &butterfly4) } } // skip past all the twiddle factors used in this layer let twiddle_offset = num_columns * (ROW_COUNT - 1); layer_twiddles = &layer_twiddles[twiddle_offset..]; - - cross_fft_len *= ROW_COUNT; } } } boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len); -unsafe fn butterfly_4( - data: &mut [Complex], - twiddles: &[Complex], - num_ffts: usize, - direction: FftDirection, -) { - let butterfly4 = Butterfly4::new(direction); - - let mut idx = 0usize; - let mut tw_idx = 0usize; - let mut scratch = [Zero::zero(); 4]; - for _ in 0..num_ffts { - scratch[0] = *data.get_unchecked(idx); - scratch[1] = *data.get_unchecked(idx + 1 * num_ffts) * twiddles[tw_idx]; - scratch[2] = *data.get_unchecked(idx + 2 * num_ffts) * twiddles[tw_idx + 1]; - scratch[3] = *data.get_unchecked(idx + 3 * num_ffts) * twiddles[tw_idx + 2]; - - butterfly4.perform_fft_butterfly(&mut scratch); - - *data.get_unchecked_mut(idx) = scratch[0]; - *data.get_unchecked_mut(idx + 1 * num_ffts) = scratch[1]; - *data.get_unchecked_mut(idx + 2 * num_ffts) = scratch[2]; - *data.get_unchecked_mut(idx + 3 * num_ffts) = scratch[3]; - - tw_idx += 3; - idx += 1; - } -} - #[cfg(test)] mod unit_tests { use super::*; diff --git a/src/algorithm/radixn.rs b/src/algorithm/radixn.rs new file mode 100644 index 00000000..bd4d35ca --- /dev/null +++ b/src/algorithm/radixn.rs @@ -0,0 +1,468 @@ +use std::sync::Arc; + +use num_complex::Complex; + +use crate::array_utils::{self, factor_transpose, Load, LoadStore, TransposeFactor}; +use crate::common::{fft_error_inplace, fft_error_outofplace}; +use crate::{common::FftNum, twiddles, FftDirection}; +use crate::{Direction, Fft, Length, RadixFactor}; + +use super::butterflies::{Butterfly2, Butterfly3, Butterfly4, Butterfly5, Butterfly6, Butterfly7}; + +#[repr(u8)] +enum InternalRadixFactor { + Factor2(Butterfly2), + Factor3(Butterfly3), + Factor4(Butterfly4), + Factor5(Butterfly5), + Factor6(Butterfly6), + Factor7(Butterfly7), +} +impl InternalRadixFactor { + pub const fn radix(&self) -> usize { + // note: if we had rustc 1.66, we could just turn these values explicit discriminators on the enum + match self { + InternalRadixFactor::Factor2(_) => 2, + InternalRadixFactor::Factor3(_) => 3, + InternalRadixFactor::Factor4(_) => 4, + InternalRadixFactor::Factor5(_) => 5, + InternalRadixFactor::Factor6(_) => 6, + InternalRadixFactor::Factor7(_) => 7, + } + } +} + +/// FFT algorithm which efficiently computes FFTs with small prime factors. +/// +/// ~~~ +/// // Computes a forward FFT of size 6720 (32 * 7 * 5 * 3 * 2) +/// use std::sync::Arc; +/// use rustfft::algorithm::{RadixN, butterflies::Butterfly32}; +/// use rustfft::{Fft, FftDirection, RadixFactor}; +/// use rustfft::num_complex::Complex; +/// +/// let mut buffer = vec![Complex{ re: 0.0f32, im: 0.0f32 }; 6720]; +/// +/// let base_fft = Arc::new(Butterfly32::new(FftDirection::Forward)); +/// let factors = &[RadixFactor::Factor7, RadixFactor::Factor5, RadixFactor::Factor3, RadixFactor::Factor2]; +/// let fft = RadixN::new(factors, base_fft); +/// fft.process(&mut buffer); +/// ~~~ +pub struct RadixN { + twiddles: Box<[Complex]>, + + base_fft: Arc>, + base_len: usize, + + factors: Box<[TransposeFactor]>, + butterflies: Box<[InternalRadixFactor]>, + + len: usize, + direction: FftDirection, + inplace_scratch_len: usize, + outofplace_scratch_len: usize, +} + +impl RadixN { + /// Constructs a RadixN instance which computes FFTs of length `factor_product * base_fft.len()` + pub fn new(factors: &[RadixFactor], base_fft: Arc>) -> Self { + let base_len = base_fft.len(); + let direction = base_fft.fft_direction(); + + // set up our cross FFT butterfly instances. simultaneously, compute the number of twiddle factors + let mut butterflies = Vec::with_capacity(factors.len()); + let mut cross_fft_len = base_len; + let mut twiddle_count = 0; + + for factor in factors { + // compute how many twiddles this cross-FFT needs + let cross_fft_rows = factor.radix(); + let cross_fft_columns = cross_fft_len; + + twiddle_count += cross_fft_columns * (cross_fft_rows - 1); + + // set up the butterfly for this cross-FFT + let butterfly = match factor { + RadixFactor::Factor2 => InternalRadixFactor::Factor2(Butterfly2::new(direction)), + RadixFactor::Factor3 => InternalRadixFactor::Factor3(Butterfly3::new(direction)), + RadixFactor::Factor4 => InternalRadixFactor::Factor4(Butterfly4::new(direction)), + RadixFactor::Factor5 => InternalRadixFactor::Factor5(Butterfly5::new(direction)), + RadixFactor::Factor6 => InternalRadixFactor::Factor6(Butterfly6::new(direction)), + RadixFactor::Factor7 => InternalRadixFactor::Factor7(Butterfly7::new(direction)), + }; + butterflies.push(butterfly); + + cross_fft_len *= cross_fft_rows; + } + let len = cross_fft_len; + + // set up our list of transpose factors - it's the same list but reversed, and we want to collapse duplicates + // Note that we are only de-duplicating adjacent factors. If we're passed 7 * 2 * 7, we can't collapse the sevens + // because the exact order of factors is is important for the transpose + let mut transpose_factors: Vec = Vec::with_capacity(factors.len()); + for f in factors.iter().rev() { + // I really want let chains for this! + let mut push_new = true; + if let Some(last) = transpose_factors.last_mut() { + if last.factor == *f { + last.count += 1; + push_new = false; + } + } + if push_new { + transpose_factors.push(TransposeFactor { + factor: *f, + count: 1, + }); + } + } + + // precompute the twiddle factors this algorithm will use. + // we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=factor.radix() and height=len/factor.radix() + // but mixed radix only does one step and then calls itself recusrively, and this algorithm does every layer all the way down + // so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up + let mut cross_fft_len = base_len; + let mut twiddle_factors = Vec::with_capacity(twiddle_count); + + for factor in factors { + // Compute the twiddle factors for the cross FFT + let cross_fft_columns = cross_fft_len; + cross_fft_len *= factor.radix(); + + for i in 0..cross_fft_columns { + for k in 1..factor.radix() { + let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction); + twiddle_factors.push(twiddle); + } + } + } + + // figure out how much scratch space we need to request from callers + let base_inplace_scratch = base_fft.get_inplace_scratch_len(); + let inplace_scratch_len = if base_inplace_scratch > len { + len + base_inplace_scratch + } else { + len + }; + let outofplace_scratch_len = if base_inplace_scratch > len { + base_inplace_scratch + } else { + 0 + }; + + Self { + twiddles: twiddle_factors.into_boxed_slice(), + + base_fft, + base_len, + + factors: transpose_factors.into_boxed_slice(), + butterflies: butterflies.into_boxed_slice(), + + len, + direction, + + inplace_scratch_len, + outofplace_scratch_len, + } + } + + fn inplace_scratch_len(&self) -> usize { + self.inplace_scratch_len + } + fn outofplace_scratch_len(&self) -> usize { + self.outofplace_scratch_len + } + + fn perform_fft_out_of_place( + &self, + input: &mut [Complex], + output: &mut [Complex], + scratch: &mut [Complex], + ) { + if let Some(unroll_factor) = self.factors.first() { + // for performance, we really, really want to unroll the transpose, but we need to make sure the output length is divisible by the unroll amount + // choosing the first factor seems to reliably perform well + match unroll_factor.factor { + RadixFactor::Factor2 => { + factor_transpose::, 2>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor3 => { + factor_transpose::, 3>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor4 => { + factor_transpose::, 4>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor5 => { + factor_transpose::, 5>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor6 => { + factor_transpose::, 6>(self.base_len, input, output, &self.factors) + } + RadixFactor::Factor7 => { + factor_transpose::, 7>(self.base_len, input, output, &self.factors) + } + } + } else { + // no factors, so just pass data straight to our base + output.copy_from_slice(input); + } + + // Base-level FFTs + let base_scratch = if scratch.len() > 0 { scratch } else { input }; + self.base_fft.process_with_scratch(output, base_scratch); + + // cross-FFTs + let mut cross_fft_len = self.base_len; + let mut layer_twiddles: &[Complex] = &self.twiddles; + + for factor in self.butterflies.iter() { + let cross_fft_columns = cross_fft_len; + cross_fft_len *= factor.radix(); + + match factor { + InternalRadixFactor::Factor2(butterfly2) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_2(data, layer_twiddles, cross_fft_columns, butterfly2) } + } + } + InternalRadixFactor::Factor3(butterfly3) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_3(data, layer_twiddles, cross_fft_columns, butterfly3) } + } + } + InternalRadixFactor::Factor4(butterfly4) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_4(data, layer_twiddles, cross_fft_columns, butterfly4) } + } + } + InternalRadixFactor::Factor5(butterfly5) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_5(data, layer_twiddles, cross_fft_columns, butterfly5) } + } + } + InternalRadixFactor::Factor6(butterfly6) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_6(data, layer_twiddles, cross_fft_columns, butterfly6) } + } + } + InternalRadixFactor::Factor7(butterfly7) => { + for data in output.chunks_exact_mut(cross_fft_len) { + unsafe { butterfly_7(data, layer_twiddles, cross_fft_columns, butterfly7) } + } + } + } + + // skip past all the twiddle factors used in this layer + let twiddle_offset = cross_fft_columns * (factor.radix() - 1); + layer_twiddles = &layer_twiddles[twiddle_offset..]; + } + } +} +boilerplate_fft_oop!(RadixN, |this: &RadixN<_>| this.len); + +#[inline(never)] +pub(crate) unsafe fn butterfly_2( + mut data: impl LoadStore, + twiddles: impl Load, + num_columns: usize, + butterfly2: &Butterfly2, +) { + for idx in 0..num_columns { + let mut scratch = [ + data.load(idx + 0 * num_columns), + data.load(idx + 1 * num_columns) * twiddles.load(idx), + ]; + + butterfly2.perform_fft_butterfly(&mut scratch); + + data.store(scratch[0], idx + num_columns * 0); + data.store(scratch[1], idx + num_columns * 1); + } +} + +#[inline(never)] +pub(crate) unsafe fn butterfly_3( + mut data: impl LoadStore, + twiddles: impl Load, + num_columns: usize, + butterfly3: &Butterfly3, +) { + for idx in 0..num_columns { + let tw_idx = idx * 2; + let mut scratch = [ + data.load(idx + 0 * num_columns), + data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0), + data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1), + ]; + + butterfly3.perform_fft_butterfly(&mut scratch); + + data.store(scratch[0], idx + 0 * num_columns); + data.store(scratch[1], idx + 1 * num_columns); + data.store(scratch[2], idx + 2 * num_columns); + } +} + +#[inline(never)] +pub(crate) unsafe fn butterfly_4( + mut data: impl LoadStore, + twiddles: impl Load, + num_columns: usize, + butterfly4: &Butterfly4, +) { + for idx in 0..num_columns { + let tw_idx = idx * 3; + let mut scratch = [ + data.load(idx + 0 * num_columns), + data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0), + data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1), + data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2), + ]; + + butterfly4.perform_fft_butterfly(&mut scratch); + + data.store(scratch[0], idx + 0 * num_columns); + data.store(scratch[1], idx + 1 * num_columns); + data.store(scratch[2], idx + 2 * num_columns); + data.store(scratch[3], idx + 3 * num_columns); + } +} + +#[inline(never)] +pub(crate) unsafe fn butterfly_5( + mut data: impl LoadStore, + twiddles: impl Load, + num_columns: usize, + butterfly5: &Butterfly5, +) { + for idx in 0..num_columns { + let tw_idx = idx * 4; + let mut scratch = [ + data.load(idx + 0 * num_columns), + data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0), + data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1), + data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2), + data.load(idx + 4 * num_columns) * twiddles.load(tw_idx + 3), + ]; + + butterfly5.perform_fft_butterfly(&mut scratch); + + data.store(scratch[0], idx + 0 * num_columns); + data.store(scratch[1], idx + 1 * num_columns); + data.store(scratch[2], idx + 2 * num_columns); + data.store(scratch[3], idx + 3 * num_columns); + data.store(scratch[4], idx + 4 * num_columns); + } +} + +#[inline(never)] +pub(crate) unsafe fn butterfly_6( + mut data: impl LoadStore, + twiddles: impl Load, + num_columns: usize, + butterfly6: &Butterfly6, +) { + for idx in 0..num_columns { + let tw_idx = idx * 5; + let mut scratch = [ + data.load(idx + 0 * num_columns), + data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0), + data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1), + data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2), + data.load(idx + 4 * num_columns) * twiddles.load(tw_idx + 3), + data.load(idx + 5 * num_columns) * twiddles.load(tw_idx + 4), + ]; + + butterfly6.perform_fft_butterfly(&mut scratch); + + data.store(scratch[0], idx + 0 * num_columns); + data.store(scratch[1], idx + 1 * num_columns); + data.store(scratch[2], idx + 2 * num_columns); + data.store(scratch[3], idx + 3 * num_columns); + data.store(scratch[4], idx + 4 * num_columns); + data.store(scratch[5], idx + 5 * num_columns); + } +} + +#[inline(never)] +pub(crate) unsafe fn butterfly_7( + mut data: impl LoadStore, + twiddles: impl Load, + num_columns: usize, + butterfly7: &Butterfly7, +) { + for idx in 0..num_columns { + let tw_idx = idx * 6; + let mut scratch = [ + data.load(idx + 0 * num_columns), + data.load(idx + 1 * num_columns) * twiddles.load(tw_idx + 0), + data.load(idx + 2 * num_columns) * twiddles.load(tw_idx + 1), + data.load(idx + 3 * num_columns) * twiddles.load(tw_idx + 2), + data.load(idx + 4 * num_columns) * twiddles.load(tw_idx + 3), + data.load(idx + 5 * num_columns) * twiddles.load(tw_idx + 4), + data.load(idx + 6 * num_columns) * twiddles.load(tw_idx + 5), + ]; + + butterfly7.perform_fft_butterfly(&mut scratch); + + data.store(scratch[0], idx + 0 * num_columns); + data.store(scratch[1], idx + 1 * num_columns); + data.store(scratch[2], idx + 2 * num_columns); + data.store(scratch[3], idx + 3 * num_columns); + data.store(scratch[4], idx + 4 * num_columns); + data.store(scratch[5], idx + 5 * num_columns); + data.store(scratch[6], idx + 6 * num_columns); + } +} + +#[cfg(test)] +mod unit_tests { + use super::*; + use crate::test_utils::{check_fft_algorithm, construct_base}; + + #[test] + fn test_scalar_radixn() { + let factor_list = &[ + RadixFactor::Factor2, + RadixFactor::Factor3, + RadixFactor::Factor4, + RadixFactor::Factor5, + RadixFactor::Factor6, + RadixFactor::Factor7, + ]; + + for base in 1..7 { + let base_forward = construct_base(base, FftDirection::Forward); + let base_inverse = construct_base(base, FftDirection::Inverse); + + // test just the base with no factors + test_radixn(&[], Arc::clone(&base_forward)); + test_radixn(&[], Arc::clone(&base_inverse)); + + // test one factor + for factor_a in factor_list { + let factors = &[*factor_a]; + test_radixn(factors, Arc::clone(&base_forward)); + test_radixn(factors, Arc::clone(&base_inverse)); + } + + // test two factors + for factor_a in factor_list { + for factor_b in factor_list { + let factors = &[*factor_a, *factor_b]; + test_radixn(factors, Arc::clone(&base_forward)); + test_radixn(factors, Arc::clone(&base_inverse)); + } + } + } + } + + fn test_radixn(factors: &[RadixFactor], base_fft: Arc>) { + let len = base_fft.len() * factors.iter().map(|f| f.radix()).product::(); + let direction = base_fft.fft_direction(); + let fft = RadixN::new(factors, base_fft); + + check_fft_algorithm::(&fft, len, direction); + } +} diff --git a/src/array_utils.rs b/src/array_utils.rs index 57c09f21..8dd0aba6 100644 --- a/src/array_utils.rs +++ b/src/array_utils.rs @@ -1,5 +1,6 @@ use crate::Complex; use crate::FftNum; +use crate::RadixFactor; use std::ops::{Deref, DerefMut}; /// Given an array of size width * height, representing a flattened 2D array, @@ -87,6 +88,25 @@ impl<'a, T: FftNum> LoadStore for DoubleBuf<'a, T> { } } +pub(crate) trait Load: Deref { + unsafe fn load(&self, idx: usize) -> Complex; +} + +impl Load for &[Complex] { + #[inline(always)] + unsafe fn load(&self, idx: usize) -> Complex { + debug_assert!(idx < self.len()); + *self.get_unchecked(idx) + } +} +impl Load for &[Complex; N] { + #[inline(always)] + unsafe fn load(&self, idx: usize) -> Complex { + debug_assert!(idx < self.len()); + *self.get_unchecked(idx) + } +} + #[cfg(test)] mod unit_tests { use super::*; @@ -217,7 +237,7 @@ pub fn bitreversed_transpose( i += 1; value }); // If we had access to rustc 1.63, we could use std::array::from_fn instead - let x_rev = x_fwd.map(|x| reverse_remainders::(x, rev_digits)); + let x_rev = x_fwd.map(|x| reverse_bits::(x, rev_digits)); // Assert that the the bit reversed indices will not exceed the length of the output. // The highest index the loop reaches is: (x_rev[n] + 1)*height - 1 @@ -243,7 +263,7 @@ pub fn bitreversed_transpose( // Repeatedly divide `value` by divisor `D`, `iters` times, and apply the remainders to a new value // When D is a power of 2, this is exactly equal (implementation and assembly)-wise to a bit reversal // When D is not a power of 2, think of this function as a logical equivalent to a bit reversal -pub fn reverse_remainders(value: usize, rev_digits: u32) -> usize { +pub fn reverse_bits(value: usize, rev_digits: u32) -> usize { assert!(D > 1); let mut result: usize = 0; @@ -275,3 +295,103 @@ pub fn compute_logarithm(value: usize) -> Option { None } } + +pub struct TransposeFactor { + pub factor: RadixFactor, + pub count: u8, +} + +// Utility to help reorder data as a part of computing RadixD FFTs. Conceputally, it works like a transpose, but with the column indexes bit-reversed. +// Use a lookup table to avoid repeating the slow bit reverse operations. +// Unrolling the outer loop by a factor D helps speed things up. +// const parameter D (for Divisor) determines how much to unroll. `input.len() / height` must divisible by D. +pub fn factor_transpose( + height: usize, + input: &[T], + output: &mut [T], + factors: &[TransposeFactor], +) { + let width = input.len() / height; + + // Let's make sure the arguments are ok + assert!(width % D == 0 && D > 1 && input.len() % width == 0 && input.len() == output.len()); + + let strided_width = width / D; + for x in 0..strided_width { + let mut i = 0; + let x_fwd = [(); D].map(|_| { + let value = D * x + i; + i += 1; + value + }); // If we had access to rustc 1.63, we could use std::array::from_fn instead + let x_rev = x_fwd.map(|x| reverse_remainders(x, factors)); + + // Assert that the the bit reversed indices will not exceed the length of the output. + // The highest index the loop reaches is: (x_rev[n] + 1)*height - 1 + // The last element of the data is at index: width*height - 1 + // Thus it is sufficient to assert that x_rev[n] usize { + let mut result: usize = 0; + let mut value = value; + for f in factors.iter() { + match f.factor { + RadixFactor::Factor2 => { + for _ in 0..f.count { + result = (result * 2) + (value % 2); + value = value / 2; + } + } + RadixFactor::Factor3 => { + for _ in 0..f.count { + result = (result * 3) + (value % 3); + value = value / 3; + } + } + RadixFactor::Factor4 => { + for _ in 0..f.count { + result = (result * 4) + (value % 4); + value = value / 4; + } + } + RadixFactor::Factor5 => { + for _ in 0..f.count { + result = (result * 5) + (value % 5); + value = value / 5; + } + } + RadixFactor::Factor6 => { + for _ in 0..f.count { + result = (result * 6) + (value % 6); + value = value / 6; + } + } + RadixFactor::Factor7 => { + for _ in 0..f.count { + result = (result * 7) + (value % 7); + value = value / 7; + } + } + } + } + result +} diff --git a/src/common.rs b/src/common.rs index 8be13267..008be6db 100644 --- a/src/common.rs +++ b/src/common.rs @@ -77,15 +77,25 @@ macro_rules! boilerplate_fft_oop { &self, input: &mut [Complex], output: &mut [Complex], - _scratch: &mut [Complex], + scratch: &mut [Complex], ) { if self.len() == 0 { return; } - if input.len() < self.len() || output.len() != input.len() { + let required_scratch = self.get_outofplace_scratch_len(); + if input.len() < self.len() + || output.len() != input.len() + || scratch.len() < required_scratch + { // We want to trigger a panic, but we want to avoid doing it in this function to reduce code size, so call a function marked cold and inline(never) that will do it for us - fft_error_outofplace(self.len(), input.len(), output.len(), 0, 0); + fft_error_outofplace( + self.len(), + input.len(), + output.len(), + required_scratch, + scratch.len(), + ); return; // Unreachable, because fft_error_outofplace asserts, but it helps codegen to put it here } @@ -94,7 +104,7 @@ macro_rules! boilerplate_fft_oop { output, self.len(), |in_chunk, out_chunk| { - self.perform_fft_out_of_place(in_chunk, out_chunk, &mut []) + self.perform_fft_out_of_place(in_chunk, out_chunk, scratch) }, ); @@ -121,9 +131,9 @@ macro_rules! boilerplate_fft_oop { return; // Unreachable, because fft_error_inplace asserts, but it helps codegen to put it here } - let scratch = &mut scratch[..required_scratch]; + let (scratch, extra_scratch) = scratch.split_at_mut(self.len()); let result = array_utils::iter_chunks(buffer, self.len(), |chunk| { - self.perform_fft_out_of_place(chunk, scratch, &mut []); + self.perform_fft_out_of_place(chunk, scratch, extra_scratch); chunk.copy_from_slice(scratch); }); @@ -140,11 +150,11 @@ macro_rules! boilerplate_fft_oop { } #[inline(always)] fn get_inplace_scratch_len(&self) -> usize { - self.len() + self.inplace_scratch_len() } #[inline(always)] fn get_outofplace_scratch_len(&self) -> usize { - 0 + self.outofplace_scratch_len() } } impl Length for $struct_name { @@ -269,3 +279,28 @@ macro_rules! boilerplate_fft { } }; } + +#[non_exhaustive] +#[repr(u8)] +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum RadixFactor { + Factor2, + Factor3, + Factor4, + Factor5, + Factor6, + Factor7, +} +impl RadixFactor { + pub const fn radix(&self) -> usize { + // note: if we had rustc 1.66, we could just turn these values explicit discriminators on the enum + match self { + RadixFactor::Factor2 => 2, + RadixFactor::Factor3 => 3, + RadixFactor::Factor4 => 4, + RadixFactor::Factor5 => 5, + RadixFactor::Factor6 => 6, + RadixFactor::Factor7 => 7, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 8eed0c32..74b47876 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,7 +134,7 @@ mod twiddles; use num_complex::Complex; use num_traits::Zero; -pub use crate::common::FftNum; +pub use crate::common::{FftNum, RadixFactor}; pub use crate::plan::{FftPlanner, FftPlannerScalar}; /// A trait that allows FFT algorithms to report their expected input/output size diff --git a/src/math_utils.rs b/src/math_utils.rs index 1c76cdc9..57d1cd2f 100644 --- a/src/math_utils.rs +++ b/src/math_utils.rs @@ -185,12 +185,13 @@ impl PrimeFactors { pub fn get_other_factors(&self) -> &[PrimeFactor] { &self.other_factors } - + #[allow(unused)] pub fn is_power_of_three(&self) -> bool { self.power_three > 0 && self.power_two == 0 && self.other_factors.len() == 0 } // Divides the number by the given prime factor. Returns None if the resulting number is one. + #[allow(unused)] pub fn remove_factors(mut self, factor: PrimeFactor) -> Option { if factor.count == 0 { return Some(self); @@ -235,6 +236,35 @@ impl PrimeFactors { None } + // returns true if we have any factors whose value is less than or equal to the provided factor + pub fn has_factors_leq(&self, factor: usize) -> bool { + self.power_two > 0 + || self.power_three > 0 + || self + .other_factors + .first() + .map_or(false, |f| f.value <= factor) + } + + // returns true if we have any factors whose value is greater than the provided factor + pub fn has_factors_gt(&self, factor: usize) -> bool { + (factor < 2 && self.power_two > 0) + || (factor < 3 && self.power_three > 0) + || self + .other_factors + .last() + .map_or(false, |f| f.value > factor) + } + + // returns the product of all factors greater than the provided min_factor + pub fn product_above(&self, min_factor: usize) -> usize { + self.other_factors + .iter() + .skip_while(|f| f.value <= min_factor) + .map(|f| f.value.pow(f.count)) + .product() + } + // Splits this set of prime factors into two different sets so that the products of the two sets are as close as possible pub fn partition_factors(mut self) -> (Self, Self) { // Make sure this isn't a prime number diff --git a/src/plan.rs b/src/plan.rs index 28500ed6..17a16203 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -7,13 +7,13 @@ use crate::{common::FftNum, fft_cache::FftCache, FftDirection}; use crate::algorithm::butterflies::*; use crate::algorithm::*; -use crate::Fft; +use crate::{Fft, RadixFactor}; use crate::FftPlannerAvx; use crate::FftPlannerNeon; use crate::FftPlannerSse; -use crate::math_utils::{PrimeFactor, PrimeFactors}; +use crate::math_utils::PrimeFactors; enum ChosenFftPlanner { Scalar(FftPlannerScalar), @@ -124,8 +124,7 @@ impl FftPlanner { } } -const MIN_RADIX4_BITS: u32 = 5; // smallest size to consider radix 4 an option is 2^5 = 32 -const MIN_RADIX3_FACTORS: u32 = 4; // smallest number of factors of 3 to consider radix 4 an option is 3^4=81. any smaller and we want to use butterflies directly. +const MAX_RADIXN_FACTOR: usize = 7; // The largest blutterfly factor that the RadixN algorithm can handle const MAX_RADER_PRIME_FACTOR: usize = 23; // don't use Raders if the inner fft length has prime factor larger than this /// A Recipe is a structure that describes the design of a FFT, without actually creating it. @@ -157,8 +156,8 @@ pub enum Recipe { len: usize, inner_fft: Arc, }, - Radix3 { - k: u32, + RadixN { + factors: Box<[RadixFactor]>, base_fft: Arc, }, Radix4 { @@ -191,7 +190,9 @@ impl Recipe { pub fn len(&self) -> usize { match self { Recipe::Dft(length) => *length, - Recipe::Radix3 { k, base_fft } => base_fft.len() * 3usize.pow(*k), + Recipe::RadixN { factors, base_fft } => { + base_fft.len() * factors.iter().map(|f| f.radix()).product::() + } Recipe::Radix4 { k, base_fft } => base_fft.len() * (1 << (k * 2)), Recipe::Butterfly2 => 2, Recipe::Butterfly3 => 3, @@ -336,9 +337,9 @@ impl FftPlannerScalar { fn build_new_fft(&mut self, recipe: &Recipe, direction: FftDirection) -> Arc> { match recipe { Recipe::Dft(len) => Arc::new(Dft::new(*len, direction)) as Arc>, - Recipe::Radix3 { k, base_fft } => { + Recipe::RadixN { factors, base_fft } => { let base_fft = self.build_fft(base_fft, direction); - Arc::new(Radix3::new_with_base(*k, base_fft)) as Arc> + Arc::new(RadixN::new(factors, base_fft)) as Arc> } Recipe::Radix4 { k, base_fft } => { let base_fft = self.build_fft(base_fft, direction); @@ -412,39 +413,63 @@ impl FftPlannerScalar { fft_instance } else if factors.is_prime() { self.design_prime(len) - } else if len.trailing_zeros() >= MIN_RADIX4_BITS { - if factors.get_other_factors().is_empty() && factors.get_power_of_three() < 2 { - self.design_radix4(factors) - } else { - let non_power_of_two = factors - .remove_factors(PrimeFactor { - value: 2, - count: len.trailing_zeros(), - }) - .unwrap(); - let power_of_two = PrimeFactors::compute(1 << len.trailing_zeros()); - self.design_mixed_radix(power_of_two, non_power_of_two) - } - } else if factors.get_power_of_three() >= MIN_RADIX3_FACTORS { - if factors.is_power_of_three() { - self.design_radix3(factors.get_power_of_three()) - } else { - let power3 = factors.get_power_of_three(); - let non_power_of_three = factors - .remove_factors(PrimeFactor { - value: 3, - count: power3, - }) - .unwrap(); - let power_of_three = PrimeFactors::compute(3usize.pow(power3)); - self.design_mixed_radix(power_of_three, non_power_of_three) - } + } else if let Some(butterfly_product) = self.design_butterfly_product(len) { + butterfly_product + } else if factors.has_factors_leq(MAX_RADIXN_FACTOR) { + self.design_radixn(factors) } else { let (left_factors, right_factors) = factors.partition_factors(); self.design_mixed_radix(left_factors, right_factors) } } + fn design_butterfly_product(&mut self, len: usize) -> Option> { + if len > 992 || len.is_power_of_two() { + return None; + } // 31*32 = 992. if we're above this size, don't bother. anddon't bother for powers of 2 because radix4 is fast + + let limit = (len as f64).sqrt().ceil() as usize + 1; + let butterflies = [ + 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 16, 17, 19, 23, 24, 27, 29, 31, 32, + ]; + + // search through our butterflies. if we find one that divides the length, see of the quotient is also a butterfly + // if it is, we have a butterfly product + // if there are multiple valid pairs, take the one with the smallest sum - we want the values to be as close together as possible + // ie 32 x 2, sum = 34, 16 x 4, sum = 20, 8 x 8, sum = 16, so even though 32,2 and 16x4 are valid, we want 8x8 + + let mut min_sum = usize::MAX; + let mut found_butterflies = None; + for left in butterflies.iter().take_while(|n| **n < limit) { + let right = len / left; + if left * right == len && butterflies.contains(&right) { + let sum = left + right; + if sum < min_sum { + min_sum = sum; + found_butterflies = Some((*left, right)) + } + } + } + + // if we found a valid pair of butterflies, construct a recipe for them + found_butterflies.map(|(left_len, right_len)| { + let left_fft = self.design_fft_for_len(left_len); + let right_fft = self.design_fft_for_len(right_len); + + if gcd(left_len, right_len) == 1 { + Arc::new(Recipe::GoodThomasAlgorithmSmall { + left_fft, + right_fft, + }) + } else { + Arc::new(Recipe::MixedRadixSmall { + left_fft, + right_fft, + }) + } + }) + } + fn design_mixed_radix( &mut self, left_factors: PrimeFactors, @@ -479,70 +504,105 @@ impl FftPlannerScalar { } } - fn design_radix3(&mut self, exponent: u32) -> Arc { - // plan a step of radix3 - let base_exponent = match exponent { - 0 => 0, - 1 => 1, - 2 => 2, - _ => 3, - }; - - let base_fft = self.design_fft_for_len(3usize.pow(base_exponent)); - Arc::new(Recipe::Radix3 { - k: exponent - base_exponent, - base_fft, - }) - } - - fn design_radix4(&mut self, factors: PrimeFactors) -> Arc { - // We can eventually relax this restriction -- it's not instrinsic to radix4, it's just that anything besides 2^n and 3*2^n hasn't been measured yet - assert!(factors.get_other_factors().is_empty() && factors.get_power_of_three() < 2); - + fn design_radixn(&mut self, factors: PrimeFactors) -> Arc { let p2 = factors.get_power_of_two(); - let base_len: usize = if factors.get_power_of_three() == 0 { - // pure power of 2 - match p2 { - // base cases. we shouldn't hit these but we might as well be ready for them - 0 => 1, - 1 => 2, - 2 => 4, - // main case: if len is a power of 4, use a base of 16, otherwise use a base of 8 - _ => { - if p2 % 2 == 1 { - 8 - } else { - 16 - } + let p3 = factors.get_power_of_three(); + let p5 = factors + .get_other_factors() + .iter() + .find_map(|f| if f.value == 5 { Some(f.count) } else { None }) // if we had rustc 1.62, we could use (f.value == 5).then_some(f.count) + .unwrap_or(0); + let p7 = factors + .get_other_factors() + .iter() + .find_map(|f| if f.value == 7 { Some(f.count) } else { None }) + .unwrap_or(0); + + let base_len: usize = if factors.has_factors_gt(MAX_RADIXN_FACTOR) { + // If we have factors larger than RadixN can handle, we *must* use the product of those factors as our base + factors.product_above(MAX_RADIXN_FACTOR) + } else if p7 == 0 && p5 == 0 && p3 < 2 { + // here we handle pure powers of 2 and 3 times a power of 2 - we want to hand these to radix4, so we need to consume the correct number of factors to leave us with 4^k + if p3 == 0 { + // pure power of 2 + assert!(p2 > 5); // butterflies should have caught this + if p2 % 2 == 1 { + 8 + } else { + 16 } - } - } else { - // we have a factor 3 that we're going to stick into the butterflies - match p2 { - // base cases. we shouldn't hit these but we might as well be ready for them - 0 => 3, - 1 => 6, - // main case: if len is 3*4^k, use a base of 12, otherwise use a base of 24 - _ => { - if p2 % 2 == 1 { - 24 - } else { - 12 - } + } else { + // 3 times a power of 2 + assert!(p2 > 3); // butterflies should have caught this + if p2 % 2 == 1 { + 24 + } else { + 12 } } + } else if p2 > 0 && p3 > 0 { + // we have a mixed bag of 2s and 3s + // todo: if we have way more 3s than 2s, benchmark using butterfly27 as the base + let excess_p2 = p2.saturating_sub(p3); + match excess_p2 { + 0 => 6, + 1 => 12, + _ => 24, + } + } else if p3 > 2 { + 27 + } else if p3 > 1 { + 9 + } else if p7 > 0 { + 7 + } else { + assert!(p5 > 0); + 5 }; // now that we know the base length, divide it out get what radix4 needs to compute - let cross_len = factors.get_product() / base_len; + let base_fft = self.design_fft_for_len(base_len); + let mut cross_len = factors.get_product() / base_len; + + // see if we can use radix4 + let cross_bits = cross_len.trailing_zeros(); + if cross_len.is_power_of_two() && cross_bits % 2 == 0 { + let k = cross_bits / 2; + return Arc::new(Recipe::Radix4 { k, base_fft }); + } + + // we weren't able to use radix4, so fall back to RadixN + // theoretically we could do this with the p2, p3, p5 etc values above, but our choice of base knocked them out of sync + let mut factors = Vec::new(); + while cross_len % 7 == 0 { + cross_len /= 7; + factors.push(RadixFactor::Factor7); + } + while cross_len % 6 == 0 { + cross_len /= 6; + factors.push(RadixFactor::Factor6); + } + while cross_len % 5 == 0 { + cross_len /= 5; + factors.push(RadixFactor::Factor5); + } + while cross_len % 3 == 0 { + cross_len /= 3; + factors.push(RadixFactor::Factor3); + } assert!(cross_len.is_power_of_two()); + // benchmarking suggests that we want to add the 4s *last*, i suspect because 4 is a better-than-usual value for the transpose let cross_bits = cross_len.trailing_zeros(); - assert!(cross_bits % 2 == 0); - let k = cross_bits / 2; + if cross_bits % 2 == 1 { + factors.push(RadixFactor::Factor2); + } + factors.extend(std::iter::repeat(RadixFactor::Factor4).take(cross_bits as usize / 2)); - let base_fft = self.design_fft_for_len(base_len); - Arc::new(Recipe::Radix4 { k, base_fft }) + Arc::new(Recipe::RadixN { + factors: factors.into_boxed_slice(), + base_fft, + }) } // Returns Some(instance) if we have a butterfly available for this size. Returns None if there is no butterfly available for this size @@ -608,13 +668,6 @@ impl FftPlannerScalar { mod unit_tests { use super::*; - fn is_mixedradix(plan: &Recipe) -> bool { - match plan { - &Recipe::MixedRadix { .. } => true, - _ => false, - } - } - fn is_mixedradixsmall(plan: &Recipe) -> bool { match plan { &Recipe::MixedRadixSmall { .. } => true, @@ -691,8 +744,8 @@ mod unit_tests { } #[test] - fn test_plan_scalar_mixedradix() { - // Products of several different primes should become MixedRadix + fn test_plan_scalar_radixn() { + // Products of several different small primes should become RadixN let mut planner = FftPlannerScalar::::new(); for pow2 in 2..5 { for pow3 in 2..5 { @@ -703,7 +756,17 @@ mod unit_tests { * 5usize.pow(pow5) * 7usize.pow(pow7); let plan = planner.design_fft_for_len(len); - assert!(is_mixedradix(&plan), "Expected MixedRadix, got {:?}", plan); + assert!( + matches!( + *plan, + Recipe::RadixN { + factors: _, + base_fft: _ + } + ), + "Expected MixedRadix, got {:?}", + plan + ); assert_eq!(plan.len(), len, "Recipe reports wrong length"); } } @@ -713,9 +776,9 @@ mod unit_tests { #[test] fn test_plan_scalar_mixedradixsmall() { - // Products of two "small" lengths < 31 that have a common divisor >1, and isn't a power of 2 should be MixedRadixSmall + // Products of two "small" butterflies < 31 that have a common divisor >1, and isn't a power of 2 should be MixedRadixSmall let mut planner = FftPlannerScalar::::new(); - for len in [5 * 20, 5 * 25].iter() { + for len in [12 * 3, 6 * 27].iter() { let plan = planner.design_fft_for_len(*len); assert!( is_mixedradixsmall(&plan), diff --git a/src/sse/mod.rs b/src/sse/mod.rs index b5be074d..df50b59c 100644 --- a/src/sse/mod.rs +++ b/src/sse/mod.rs @@ -17,10 +17,6 @@ use std::arch::x86_64::__m128d; use crate::FftNum; -pub use self::sse_butterflies::*; -pub use self::sse_prime_butterflies::*; -pub use self::sse_radix4::*; - use sse_vector::SseVector; pub trait SseNum: FftNum { diff --git a/tests/accuracy.rs b/tests/accuracy.rs index b20825cc..a268cff2 100644 --- a/tests/accuracy.rs +++ b/tests/accuracy.rs @@ -114,6 +114,7 @@ fn test_planned_fft_forward_f32() { let cache: ControlCache = ControlCache::new(TEST_MAX, direction); for len in 1..TEST_MAX { + println!("len: {len}"); let control = cache.plan_fft(len); assert_eq!(control.len(), len); assert_eq!(control.fft_direction(), direction);