Skip to content

Commit

Permalink
Rewrote radix4 to be simpler, more readable, slightly faster
Browse files Browse the repository at this point in the history
  • Loading branch information
ejmahler committed Dec 27, 2020
1 parent b9b9a1d commit 6ab56f9
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 96 deletions.
62 changes: 61 additions & 1 deletion src/algorithm/butterflies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,66 @@ macro_rules! boilerplate_fft_butterfly {
};
}

/// FFT "butterfly" of length 1.
///
/// Its `Fft` implementation copies input to output for out-of-place calls and
/// leaves the buffer untouched for in-place calls.
pub struct Butterfly1<T> {
// Direction flag supplied to `new`; it is only reported back via `is_inverse`
// and is not read by any of the processing methods.
inverse: bool,
// Zero-sized marker that lets the struct carry the type parameter `T`
// without storing a value of that type.
_phantom: std::marker::PhantomData<T>,
}
impl<T: FFTnum> Butterfly1<T> {
    /// Builds a length-1 butterfly.
    ///
    /// `inverse` is stored as-is and later returned by `is_inverse`.
    #[inline(always)]
    pub fn new(inverse: bool) -> Self {
        // Zero-sized marker tying this instance to the element type `T`.
        let _phantom = std::marker::PhantomData::<T>;
        Self { inverse, _phantom }
    }
}
/// `Fft` implementation for the length-1 butterfly.
///
/// A one-point transform leaves its single element unchanged, so the
/// out-of-place methods are plain copies and the in-place methods do nothing.
/// No scratch space is ever requested or used.
impl<T: FFTnum> Fft<T> for Butterfly1<T> {
    /// In-place processing needs no scratch space.
    fn get_inplace_scratch_len(&self) -> usize {
        0
    }

    /// Out-of-place processing needs no scratch space.
    fn get_out_of_place_scratch_len(&self) -> usize {
        0
    }

    /// Copies `input` into `output`; `_scratch` is ignored.
    ///
    /// # Panics
    ///
    /// Panics if `input` and `output` have different lengths
    /// (the behaviour of `copy_from_slice`).
    fn process_with_scratch(
        &self,
        input: &mut [Complex<T>],
        output: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
        output.copy_from_slice(input);
    }

    /// Copies the whole of `input` into `output`; `_scratch` is ignored.
    ///
    /// # Panics
    ///
    /// Panics if `input` and `output` have different lengths
    /// (the behaviour of `copy_from_slice`).
    fn process_multi(
        &self,
        input: &mut [Complex<T>],
        output: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
        output.copy_from_slice(input);
    }

    /// Intentionally a no-op: the buffer is left exactly as it was passed in.
    fn process_inplace_with_scratch(
        &self,
        _buffer: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
    }

    /// Intentionally a no-op: the buffer is left exactly as it was passed in.
    fn process_inplace_multi(
        &self,
        _buffer: &mut [Complex<T>],
        _scratch: &mut [Complex<T>],
    ) {
    }
}
impl<T> Length for Butterfly1<T> {
    /// Always 1, regardless of how the instance was constructed.
    fn len(&self) -> usize {
        1
    }
}
impl<T> IsInverse for Butterfly1<T> {
    /// Returns the direction flag that was handed to `Butterfly1::new`.
    fn is_inverse(&self) -> bool {
        let Self { inverse, .. } = *self;
        inverse
    }
}

pub struct Butterfly2<T> {
inverse: bool,
_phantom: std::marker::PhantomData<T>,
Expand Down Expand Up @@ -222,7 +282,7 @@ impl<T: FFTnum> Butterfly4<T> {
}
}
#[inline(always)]
unsafe fn perform_fft_contiguous(
pub(crate) unsafe fn perform_fft_contiguous(
&self,
input: RawSlice<Complex<T>>,
output: RawSliceMut<Complex<T>>,
Expand Down
179 changes: 84 additions & 95 deletions src/algorithm/radix4.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
use std::sync::Arc;

use num_complex::Complex;
use num_traits::Zero;

use crate::common::FFTnum;
use crate::{
array_utils::{RawSlice, RawSliceMut},
common::FFTnum,
};

use crate::algorithm::butterflies::{Butterfly16, Butterfly2, Butterfly4, Butterfly8};
use crate::algorithm::butterflies::{Butterfly1, Butterfly16, Butterfly2, Butterfly4, Butterfly8};
use crate::{Fft, IsInverse, Length};

/// FFT algorithm optimized for power-of-two sizes
Expand All @@ -24,8 +29,10 @@ use crate::{Fft, IsInverse, Length};
pub struct Radix4<T> {
twiddles: Box<[Complex<T>]>,
butterfly8: Butterfly8<T>,
butterfly16: Butterfly16<T>,

base_fft: Arc<dyn Fft<T>>,
base_len: usize,

len: usize,
inverse: bool,
}
Expand All @@ -39,17 +46,26 @@ impl<T: FFTnum> Radix4<T> {
len
);

// precompute the twiddle factors this algorithm will use.
// we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=4 and height=len/4
// but mixed radix only does one step and then calls itself recursively, and this algorithm does every layer all the way down
// so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom and going up
// figure out which base length we're going to use
let num_bits = len.trailing_zeros();
let mut twiddle_stride = if num_bits % 2 == 0 {
len / 64
} else {
len / 32
let (base_len, base_fft) = match num_bits {
0 => (len, Arc::new(Butterfly1::new(inverse)) as Arc<dyn Fft<T>>),
1 => (len, Arc::new(Butterfly2::new(inverse)) as Arc<dyn Fft<T>>),
2 => (len, Arc::new(Butterfly4::new(inverse)) as Arc<dyn Fft<T>>),
_ => {
if num_bits % 2 == 1 {
(8, Arc::new(Butterfly8::new(inverse)) as Arc<dyn Fft<T>>)
} else {
(16, Arc::new(Butterfly16::new(inverse)) as Arc<dyn Fft<T>>)
}
}
};

// precompute the twiddle factors this algorithm will use.
// we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=4 and height=len/4
// but mixed radix only does one step and then calls itself recursively, and this algorithm does every layer all the way down
// so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
let mut twiddle_stride = len / (base_len * 4);
let mut twiddle_factors = Vec::with_capacity(len * 2);
while twiddle_stride > 0 {
let num_rows = len / (twiddle_stride * 4);
Expand All @@ -64,8 +80,10 @@ impl<T: FFTnum> Radix4<T> {

Self {
twiddles: twiddle_factors.into_boxed_slice(),
butterfly8: Butterfly8::new(inverse),
butterfly16: Butterfly16::new(inverse),

base_fft,
base_len,

len,
inverse,
}
Expand All @@ -77,58 +95,35 @@ impl<T: FFTnum> Radix4<T> {
spectrum: &mut [Complex<T>],
_scratch: &mut [Complex<T>],
) {
match self.len() {
0 | 1 => spectrum.copy_from_slice(signal),
2 => {
spectrum.copy_from_slice(signal);
unsafe { Butterfly2::new(self.inverse).perform_fft_butterfly(spectrum) }
}
4 => {
spectrum.copy_from_slice(signal);
unsafe { Butterfly4::new(self.inverse).perform_fft_butterfly(spectrum) }
}
_ => {
// copy the data into the spectrum vector
prepare_radix4(signal.len(), signal, spectrum, 1);
// copy the data into the spectrum vector
prepare_radix4(signal.len(), self.base_len, signal, spectrum, 1);

// perform the butterflies. the butterfly size depends on the input size
let num_bits = signal.len().trailing_zeros();
let mut current_size = if num_bits % 2 == 0 {
self.butterfly16.process_inplace_multi(spectrum, &mut []);
// Base-level FFTs
self.base_fft.process_inplace_multi(spectrum, &mut []);

// for the cross-ffts we want to start off with a size of 64 (16 * 4)
64
} else {
self.butterfly8.process_inplace_multi(spectrum, &mut []);

// for the cross-ffts we want to start off with a size of 32 (8 * 4)
32
};

let mut layer_twiddles: &[Complex<T>] = &self.twiddles;

// now, perform all the cross-FFTs, one "layer" at a time
while current_size <= signal.len() {
let num_rows = signal.len() / current_size;

for i in 0..num_rows {
unsafe {
butterfly_4(
&mut spectrum[i * current_size..],
layer_twiddles,
current_size / 4,
self.inverse,
)
}
}

//skip past all the twiddle factors used in this layer
let twiddle_offset = (current_size * 3) / 4;
layer_twiddles = &layer_twiddles[twiddle_offset..];

current_size *= 4;
// cross-FFTs
let mut current_size = self.base_len * 4;
let mut layer_twiddles: &[Complex<T>] = &self.twiddles;

while current_size <= signal.len() {
let num_rows = signal.len() / current_size;

for i in 0..num_rows {
unsafe {
butterfly_4(
&mut spectrum[i * current_size..],
layer_twiddles,
current_size / 4,
self.inverse,
)
}
}

//skip past all the twiddle factors used in this layer
let twiddle_offset = (current_size * 3) / 4;
layer_twiddles = &layer_twiddles[twiddle_offset..];

current_size *= 4;
}
}
}
Expand All @@ -138,25 +133,26 @@ boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len);
// was almost an order of magnitude faster at setting up
fn prepare_radix4<T: FFTnum>(
size: usize,
base_len: usize,
signal: &[Complex<T>],
spectrum: &mut [Complex<T>],
stride: usize,
) {
match size {
2 | 4 | 8 | 16 => unsafe {
if size == base_len {
unsafe {
for i in 0..size {
*spectrum.get_unchecked_mut(i) = *signal.get_unchecked(i * stride);
}
},
_ => {
for i in 0..4 {
prepare_radix4(
size / 4,
&signal[i * stride..],
&mut spectrum[i * (size / 4)..],
stride * 4,
);
}
}
} else {
for i in 0..4 {
prepare_radix4(
size / 4,
base_len,
&signal[i * stride..],
&mut spectrum[i * (size / 4)..],
stride * 4,
);
}
}
}
Expand All @@ -167,30 +163,23 @@ unsafe fn butterfly_4<T: FFTnum>(
num_ffts: usize,
inverse: bool,
) {
let butterfly4 = Butterfly4::new(inverse);

let mut idx = 0usize;
let mut tw_idx = 0usize;
let mut scratch: [Complex<T>; 6] = [Zero::zero(); 6];
let mut scratch = [Zero::zero(); 4];
for _ in 0..num_ffts {
scratch[0] = data.get_unchecked(idx + 1 * num_ffts) * twiddles[tw_idx];
scratch[1] = data.get_unchecked(idx + 2 * num_ffts) * twiddles[tw_idx + 1];
scratch[2] = data.get_unchecked(idx + 3 * num_ffts) * twiddles[tw_idx + 2];
scratch[5] = data.get_unchecked(idx) - scratch[1];
*data.get_unchecked_mut(idx) = data.get_unchecked(idx) + scratch[1];
scratch[3] = scratch[0] + scratch[2];
scratch[4] = scratch[0] - scratch[2];
*data.get_unchecked_mut(idx + 2 * num_ffts) = data.get_unchecked(idx) - scratch[3];
*data.get_unchecked_mut(idx) = data.get_unchecked(idx) + scratch[3];
if inverse {
data.get_unchecked_mut(idx + num_ffts).re = scratch[5].re - scratch[4].im;
data.get_unchecked_mut(idx + num_ffts).im = scratch[5].im + scratch[4].re;
data.get_unchecked_mut(idx + 3 * num_ffts).re = scratch[5].re + scratch[4].im;
data.get_unchecked_mut(idx + 3 * num_ffts).im = scratch[5].im - scratch[4].re;
} else {
data.get_unchecked_mut(idx + num_ffts).re = scratch[5].re + scratch[4].im;
data.get_unchecked_mut(idx + num_ffts).im = scratch[5].im - scratch[4].re;
data.get_unchecked_mut(idx + 3 * num_ffts).re = scratch[5].re - scratch[4].im;
data.get_unchecked_mut(idx + 3 * num_ffts).im = scratch[5].im + scratch[4].re;
}
scratch[0] = *data.get_unchecked(idx);
scratch[1] = *data.get_unchecked(idx + 1 * num_ffts) * twiddles[tw_idx];
scratch[2] = *data.get_unchecked(idx + 2 * num_ffts) * twiddles[tw_idx + 1];
scratch[3] = *data.get_unchecked(idx + 3 * num_ffts) * twiddles[tw_idx + 2];

butterfly4.perform_fft_contiguous(RawSlice::new(&scratch), RawSliceMut::new(&mut scratch));

*data.get_unchecked_mut(idx) = scratch[0];
*data.get_unchecked_mut(idx + 1 * num_ffts) = scratch[1];
*data.get_unchecked_mut(idx + 2 * num_ffts) = scratch[2];
*data.get_unchecked_mut(idx + 3 * num_ffts) = scratch[3];

tw_idx += 3;
idx += 1;
Expand Down

0 comments on commit 6ab56f9

Please sign in to comment.