Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RadixN: Like Radix4, but supports size 2,3,4,5,6, and 7 cross-FFTs all in the same instance #132

Merged
merged 6 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/algorithm/dft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ impl<T: FftNum> Dft<T> {
}
}

/// Scratch length reported for in-place processing: the same value as `self.len()`.
fn inplace_scratch_len(&self) -> usize {
self.len()
}
/// Scratch length reported for out-of-place processing: always zero.
fn outofplace_scratch_len(&self) -> usize {
0
}

fn perform_fft_out_of_place(
&self,
signal: &[Complex<T>],
Expand Down
2 changes: 2 additions & 0 deletions src/algorithm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod mixed_radix;
mod raders_algorithm;
mod radix3;
mod radix4;
mod radixn;

/// Hardcoded size-specific FFT algorithms
pub mod butterflies;
Expand All @@ -16,3 +17,4 @@ pub use self::mixed_radix::{MixedRadix, MixedRadixSmall};
pub use self::raders_algorithm::RadersAlgorithm;
pub use self::radix3::Radix3;
pub use self::radix4::Radix4;
pub use self::radixn::RadixN;
89 changes: 40 additions & 49 deletions src/algorithm/radix3.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::sync::Arc;

use num_complex::Complex;
use num_traits::Zero;

use crate::algorithm::butterflies::{Butterfly1, Butterfly27, Butterfly3, Butterfly9};
use crate::algorithm::radixn::butterfly_3;
use crate::array_utils::{self, bitreversed_transpose, compute_logarithm};
use crate::common::{fft_error_inplace, fft_error_outofplace};
use crate::{common::FftNum, twiddles, FftDirection};
Expand Down Expand Up @@ -32,6 +32,8 @@ pub struct Radix3<T> {

len: usize,
direction: FftDirection,
inplace_scratch_len: usize,
outofplace_scratch_len: usize,
}

impl<T: FftNum> Radix3<T> {
Expand Down Expand Up @@ -68,20 +70,32 @@ impl<T: FftNum> Radix3<T> {
// but mixed radix only does one step and then calls itself recursively, and this algorithm does every layer all the way down
// so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
const ROW_COUNT: usize = 3;
let mut cross_fft_len = base_len * ROW_COUNT;
let mut cross_fft_len = base_len;
let mut twiddle_factors = Vec::with_capacity(len * 2);
while cross_fft_len <= len {
let num_columns = cross_fft_len / ROW_COUNT;
while cross_fft_len < len {
let num_columns = cross_fft_len;
cross_fft_len *= ROW_COUNT;

for i in 0..num_columns {
for k in 1..ROW_COUNT {
let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
twiddle_factors.push(twiddle);
}
}
cross_fft_len *= ROW_COUNT;
}

let base_inplace_scratch = base_fft.get_inplace_scratch_len();
let inplace_scratch_len = if base_inplace_scratch > cross_fft_len {
cross_fft_len + base_inplace_scratch
} else {
cross_fft_len
};
let outofplace_scratch_len = if base_inplace_scratch > len {
base_inplace_scratch
} else {
0
};

Self {
twiddles: twiddle_factors.into_boxed_slice(),
butterfly3: Butterfly3::new(direction),
Expand All @@ -91,14 +105,24 @@ impl<T: FftNum> Radix3<T> {

len,
direction,

inplace_scratch_len,
outofplace_scratch_len,
}
}

/// Returns the in-place scratch length stored in the `inplace_scratch_len` field
/// when this instance was built.
fn inplace_scratch_len(&self) -> usize {
self.inplace_scratch_len
}
/// Returns the out-of-place scratch length stored in the `outofplace_scratch_len` field
/// when this instance was built.
fn outofplace_scratch_len(&self) -> usize {
self.outofplace_scratch_len
}

fn perform_fft_out_of_place(
&self,
input: &[Complex<T>],
input: &mut [Complex<T>],
output: &mut [Complex<T>],
_scratch: &mut [Complex<T>],
scratch: &mut [Complex<T>],
) {
// copy the data into the output vector
if self.len() == self.base_len {
Expand All @@ -108,63 +132,30 @@ impl<T: FftNum> Radix3<T> {
}

// Base-level FFTs
self.base_fft.process_with_scratch(output, &mut []);
let base_scratch = if scratch.len() > 0 { scratch } else { input };
self.base_fft.process_with_scratch(output, base_scratch);

// cross-FFTs
const ROW_COUNT: usize = 3;
let mut cross_fft_len = self.base_len * ROW_COUNT;
let mut cross_fft_len = self.base_len;
let mut layer_twiddles: &[Complex<T>] = &self.twiddles;

while cross_fft_len <= input.len() {
let num_rows = input.len() / cross_fft_len;
let num_columns = cross_fft_len / ROW_COUNT;

for i in 0..num_rows {
unsafe {
butterfly_3(
&mut output[i * cross_fft_len..],
layer_twiddles,
num_columns,
&self.butterfly3,
)
}
while cross_fft_len < output.len() {
let num_columns = cross_fft_len;
cross_fft_len *= ROW_COUNT;

for data in output.chunks_exact_mut(cross_fft_len) {
unsafe { butterfly_3(data, layer_twiddles, num_columns, &self.butterfly3) }
}

// skip past all the twiddle factors used in this layer
let twiddle_offset = num_columns * (ROW_COUNT - 1);
layer_twiddles = &layer_twiddles[twiddle_offset..];

cross_fft_len *= ROW_COUNT;
}
}
}
boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len);

/// Applies `num_ffts` twiddled size-3 butterflies to `data`, in place.
///
/// For each column `c` in `0..num_ffts`, the three elements at `c`, `c + num_ffts` and
/// `c + 2 * num_ffts` are gathered; the second and third are multiplied by the twiddle
/// factors at `2 * c` and `2 * c + 1`, `butterfly3` is run on the triple, and the results
/// are written back to the positions they were read from.
///
/// # Safety
///
/// `data` is read and written without bounds checks, so it must contain at least
/// `3 * num_ffts` elements. `twiddles` is indexed with ordinary bounds checks and will
/// panic if it holds fewer than `2 * num_ffts` entries.
unsafe fn butterfly_3<T: FftNum>(
    data: &mut [Complex<T>],
    twiddles: &[Complex<T>],
    num_ffts: usize,
    butterfly3: &Butterfly3<T>,
) {
    let mut column_values = [Zero::zero(); 3];

    for column in 0..num_ffts {
        // Each column consumes two consecutive twiddle factors.
        let twiddle_base = column * 2;

        // Gather the strided column, applying the twiddles to rows 1 and 2.
        column_values[0] = *data.get_unchecked(column);
        column_values[1] = *data.get_unchecked(column + num_ffts) * twiddles[twiddle_base];
        column_values[2] =
            *data.get_unchecked(column + 2 * num_ffts) * twiddles[twiddle_base + 1];

        butterfly3.perform_fft_butterfly(&mut column_values);

        // Scatter the results back to the same strided positions.
        *data.get_unchecked_mut(column) = column_values[0];
        *data.get_unchecked_mut(column + num_ffts) = column_values[1];
        *data.get_unchecked_mut(column + 2 * num_ffts) = column_values[2];
    }
}

#[cfg(test)]
mod unit_tests {
use super::*;
Expand Down
95 changes: 42 additions & 53 deletions src/algorithm/radix4.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::sync::Arc;

use num_complex::Complex;
use num_traits::Zero;

use crate::algorithm::butterflies::{
Butterfly1, Butterfly16, Butterfly2, Butterfly32, Butterfly4, Butterfly8,
};
use crate::algorithm::radixn::butterfly_4;
use crate::array_utils::{self, bitreversed_transpose};
use crate::common::{fft_error_inplace, fft_error_outofplace};
use crate::{common::FftNum, twiddles, FftDirection};
Expand Down Expand Up @@ -33,6 +33,8 @@ pub struct Radix4<T> {

len: usize,
direction: FftDirection,
inplace_scratch_len: usize,
outofplace_scratch_len: usize,
}

impl<T: FftNum> Radix4<T> {
Expand Down Expand Up @@ -75,20 +77,32 @@ impl<T: FftNum> Radix4<T> {
// but mixed radix only does one step and then calls itself recursively, and this algorithm does every layer all the way down
// so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
const ROW_COUNT: usize = 4;
let mut cross_fft_len = base_len * ROW_COUNT;
let mut cross_fft_len = base_len;
let mut twiddle_factors = Vec::with_capacity(len * 2);
while cross_fft_len <= len {
let num_columns = cross_fft_len / ROW_COUNT;
while cross_fft_len < len {
let num_columns = cross_fft_len;
cross_fft_len *= ROW_COUNT;

for i in 0..num_columns {
for k in 1..ROW_COUNT {
let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
twiddle_factors.push(twiddle);
}
}
cross_fft_len *= ROW_COUNT;
}

let base_inplace_scratch = base_fft.get_inplace_scratch_len();
let inplace_scratch_len = if base_inplace_scratch > cross_fft_len {
cross_fft_len + base_inplace_scratch
} else {
cross_fft_len
};
let outofplace_scratch_len = if base_inplace_scratch > len {
base_inplace_scratch
} else {
0
};

Self {
twiddles: twiddle_factors.into_boxed_slice(),

Expand All @@ -97,14 +111,24 @@ impl<T: FftNum> Radix4<T> {

len,
direction,

inplace_scratch_len,
outofplace_scratch_len,
}
}

/// Returns the in-place scratch length stored in the `inplace_scratch_len` field
/// when this instance was built.
fn inplace_scratch_len(&self) -> usize {
self.inplace_scratch_len
}
/// Returns the out-of-place scratch length stored in the `outofplace_scratch_len` field
/// when this instance was built.
fn outofplace_scratch_len(&self) -> usize {
self.outofplace_scratch_len
}

fn perform_fft_out_of_place(
&self,
input: &[Complex<T>],
input: &mut [Complex<T>],
output: &mut [Complex<T>],
_scratch: &mut [Complex<T>],
scratch: &mut [Complex<T>],
) {
// copy the data into the output vector
if self.len() == self.base_len {
Expand All @@ -114,67 +138,32 @@ impl<T: FftNum> Radix4<T> {
}

// Base-level FFTs
self.base_fft.process_with_scratch(output, &mut []);
let base_scratch = if scratch.len() > 0 { scratch } else { input };
self.base_fft.process_with_scratch(output, base_scratch);

// cross-FFTs
const ROW_COUNT: usize = 4;
let mut cross_fft_len = self.base_len * ROW_COUNT;
let mut cross_fft_len = self.base_len;
let mut layer_twiddles: &[Complex<T>] = &self.twiddles;

while cross_fft_len <= input.len() {
let num_rows = input.len() / cross_fft_len;
let num_columns = cross_fft_len / ROW_COUNT;

for i in 0..num_rows {
unsafe {
butterfly_4(
&mut output[i * cross_fft_len..],
layer_twiddles,
num_columns,
self.direction,
)
}
let butterfly4 = Butterfly4::new(self.direction);

while cross_fft_len < output.len() {
let num_columns = cross_fft_len;
cross_fft_len *= ROW_COUNT;

for data in output.chunks_exact_mut(cross_fft_len) {
unsafe { butterfly_4(data, layer_twiddles, num_columns, &butterfly4) }
}

// skip past all the twiddle factors used in this layer
let twiddle_offset = num_columns * (ROW_COUNT - 1);
layer_twiddles = &layer_twiddles[twiddle_offset..];

cross_fft_len *= ROW_COUNT;
}
}
}
boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len);

/// Applies `num_ffts` twiddled size-4 butterflies to `data`, in place.
///
/// A `Butterfly4` is built for `direction` on every call. For each column `c` in
/// `0..num_ffts`, the four elements at `c`, `c + num_ffts`, `c + 2 * num_ffts` and
/// `c + 3 * num_ffts` are gathered; the last three are multiplied by the twiddle factors
/// at `3 * c`, `3 * c + 1` and `3 * c + 2`, the butterfly is run on the group, and the
/// results are written back to the positions they were read from.
///
/// # Safety
///
/// `data` is read and written without bounds checks, so it must contain at least
/// `4 * num_ffts` elements. `twiddles` is indexed with ordinary bounds checks and will
/// panic if it holds fewer than `3 * num_ffts` entries.
unsafe fn butterfly_4<T: FftNum>(
    data: &mut [Complex<T>],
    twiddles: &[Complex<T>],
    num_ffts: usize,
    direction: FftDirection,
) {
    let butterfly4 = Butterfly4::new(direction);
    let mut column_values = [Zero::zero(); 4];

    for column in 0..num_ffts {
        // Each column consumes three consecutive twiddle factors.
        let twiddle_base = column * 3;

        // Gather the strided column, applying the twiddles to rows 1 through 3.
        column_values[0] = *data.get_unchecked(column);
        column_values[1] = *data.get_unchecked(column + num_ffts) * twiddles[twiddle_base];
        column_values[2] =
            *data.get_unchecked(column + 2 * num_ffts) * twiddles[twiddle_base + 1];
        column_values[3] =
            *data.get_unchecked(column + 3 * num_ffts) * twiddles[twiddle_base + 2];

        butterfly4.perform_fft_butterfly(&mut column_values);

        // Scatter the results back to the same strided positions.
        *data.get_unchecked_mut(column) = column_values[0];
        *data.get_unchecked_mut(column + num_ffts) = column_values[1];
        *data.get_unchecked_mut(column + 2 * num_ffts) = column_values[2];
        *data.get_unchecked_mut(column + 3 * num_ffts) = column_values[3];
    }
}

#[cfg(test)]
mod unit_tests {
use super::*;
Expand Down
Loading
Loading