Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Radix4 cleanup #129

Merged
merged 2 commits into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/algorithm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ pub use self::good_thomas_algorithm::{GoodThomasAlgorithm, GoodThomasAlgorithmSm
pub use self::mixed_radix::{MixedRadix, MixedRadixSmall};
pub use self::raders_algorithm::RadersAlgorithm;
pub use self::radix3::Radix3;
pub use self::radix4::{bitreversed_transpose, Radix4};
pub use self::radix4::Radix4;
135 changes: 27 additions & 108 deletions src/algorithm/radix3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use num_complex::Complex;
use num_traits::Zero;

use crate::algorithm::butterflies::{Butterfly1, Butterfly27, Butterfly3, Butterfly9};
use crate::array_utils;
use crate::array_utils::{self, bitreversed_transpose, compute_logarithm};
use crate::common::{fft_error_inplace, fft_error_outofplace};
use crate::{common::FftNum, twiddles, FftDirection};
use crate::{Direction, Fft, Length};
Expand Down Expand Up @@ -38,7 +38,7 @@ impl<T: FftNum> Radix3<T> {
/// Preallocates necessary arrays and precomputes necessary data to efficiently compute the power-of-three FFT
pub fn new(len: usize, direction: FftDirection) -> Self {
// Compute the total power of 3 for this length. IE, len = 3^exponent
let exponent = compute_logarithm(len, 3).unwrap_or_else(|| {
let exponent = compute_logarithm::<3>(len).unwrap_or_else(|| {
panic!(
"Radix3 algorithm requires a power-of-three input size. Got {}",
len
Expand All @@ -57,17 +57,19 @@ impl<T: FftNum> Radix3<T> {
// we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=3 and height=len/3
// but mixed radix only does one step and then calls itself recursively, and this algorithm does every layer all the way down
// so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
let mut twiddle_stride = len / (base_len * 3);
const ROW_COUNT: usize = 3;
let mut cross_fft_len = base_len * ROW_COUNT;
let mut twiddle_factors = Vec::with_capacity(len * 2);
while twiddle_stride > 0 {
let num_rows = len / (twiddle_stride * 3);
for i in 0..num_rows {
for k in 1..3 {
let twiddle = twiddles::compute_twiddle(i * k * twiddle_stride, len, direction);
while cross_fft_len <= len {
let num_columns = cross_fft_len / ROW_COUNT;

for i in 0..num_columns {
for k in 1..ROW_COUNT {
let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
twiddle_factors.push(twiddle);
}
}
twiddle_stride /= 3;
cross_fft_len *= ROW_COUNT;
}

Self {
Expand All @@ -84,133 +86,50 @@ impl<T: FftNum> Radix3<T> {

fn perform_fft_out_of_place(
&self,
signal: &[Complex<T>],
spectrum: &mut [Complex<T>],
input: &[Complex<T>],
output: &mut [Complex<T>],
_scratch: &mut [Complex<T>],
) {
// copy the data into the spectrum vector
// copy the data into the output vector
if self.len() == self.base_len {
spectrum.copy_from_slice(signal);
output.copy_from_slice(input);
} else {
bitreversed_transpose(self.base_len, signal, spectrum);
bitreversed_transpose::<Complex<T>, 3>(self.base_len, input, output);
}

// Base-level FFTs
self.base_fft.process_with_scratch(spectrum, &mut []);
self.base_fft.process_with_scratch(output, &mut []);

// cross-FFTs
let mut current_size = self.base_len * 3;
const ROW_COUNT: usize = 3;
let mut cross_fft_len = self.base_len * ROW_COUNT;
let mut layer_twiddles: &[Complex<T>] = &self.twiddles;

while current_size <= signal.len() {
let num_rows = signal.len() / current_size;
while cross_fft_len <= input.len() {
let num_rows = input.len() / cross_fft_len;
let num_columns = cross_fft_len / ROW_COUNT;

for i in 0..num_rows {
unsafe {
butterfly_3(
&mut spectrum[i * current_size..],
&mut output[i * cross_fft_len..],
layer_twiddles,
current_size / 3,
num_columns,
&self.butterfly3,
)
}
}

//skip past all the twiddle factors used in this layer
let twiddle_offset = (current_size * 2) / 3;
// skip past all the twiddle factors used in this layer
let twiddle_offset = num_columns * (ROW_COUNT - 1);
layer_twiddles = &layer_twiddles[twiddle_offset..];

current_size *= 3;
cross_fft_len *= ROW_COUNT;
}
}
}
boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len);

/// Preparing for radix 3 is similar to a transpose, where the column index is digit-reversed in base 3.
///
/// `input` is read as a row-major matrix with `height` rows and `width = input.len() / height`
/// columns. The element at column `x`, row `y` is written to
/// `output[y + reverse_bits(x, rev_digits) * height]`, where `rev_digits` is the base-3 logarithm
/// of `width`.
///
/// The outer loop is unrolled by a factor of 3 (three columns per iteration), which helps speed
/// things up. As a consequence nothing is written when `width` is 1; in that case the reordering
/// is the identity and the caller is expected to copy the data itself.
///
/// # Panics
///
/// Panics if `height` is zero, if `input.len()` is not a multiple of `height`, if `input` and
/// `output` have different lengths, or if `width` is not a power of three.
pub fn bitreversed_transpose<T: Copy>(height: usize, input: &[T], output: &mut [T]) {
    // Let's make sure the arguments are ok before doing anything else:
    // the unchecked accesses below depend on every one of these conditions.
    assert!(height > 0 && input.len() % height == 0 && input.len() == output.len());

    let width = input.len() / height;
    let third_width = width / 3;

    // Number of base-3 digits in a column index. Fails loudly if width is not a power of three.
    let rev_digits = compute_logarithm(width, 3).unwrap();

    for x in 0..third_width {
        let x0 = 3 * x;
        let x1 = 3 * x + 1;
        let x2 = 3 * x + 2;

        let x_rev = [
            reverse_bits(x0, rev_digits),
            reverse_bits(x1, rev_digits),
            reverse_bits(x2, rev_digits),
        ];

        // Assert that the digit-reversed indices will not exceed the length of the output.
        // The highest index the loop reaches is: (x_rev[n] + 1)*height - 1
        // The last element of the data is at index: width*height - 1
        // Thus it is sufficient to assert that x_rev[n]<width.
        assert!(x_rev[0] < width && x_rev[1] < width && x_rev[2] < width);

        for y in 0..height {
            let input_index0 = x0 + y * width;
            let input_index1 = x1 + y * width;
            let input_index2 = x2 + y * width;
            let output_index0 = y + x_rev[0] * height;
            let output_index1 = y + x_rev[1] * height;
            let output_index2 = y + x_rev[2] * height;

            // SAFETY: x0..=x2 are all < width and y < height, so every input index is at most
            // (width - 1) + (height - 1) * width = width * height - 1 < input.len().
            // Every x_rev[n] < width was asserted above, so every output index is at most
            // (height - 1) + (width - 1) * height = width * height - 1 < output.len(),
            // using input.len() == output.len() == width * height from the first assertion.
            unsafe {
                let temp0 = *input.get_unchecked(input_index0);
                let temp1 = *input.get_unchecked(input_index1);
                let temp2 = *input.get_unchecked(input_index2);

                *output.get_unchecked_mut(output_index0) = temp0;
                *output.get_unchecked_mut(output_index1) = temp1;
                *output.get_unchecked_mut(output_index2) = temp2;
            }
        }
    }
}

/// Computes `n` such that `base ^ n == value`.
///
/// Returns `Some(n)` if `value` is a perfect power of `base` (including `Some(0)` for
/// `value == 1`), otherwise returns `None`.
///
/// `None` is also returned when `value` is zero or when `base` is smaller than 2: no power of 0
/// or 1 identifies a unique exponent, and dividing repeatedly by 1 would never terminate.
fn compute_logarithm(value: usize, base: usize) -> Option<usize> {
    if value == 0 || base < 2 {
        return None;
    }

    let mut current_exponent = 0;
    let mut current_value = value;

    // Strip one factor of `base` per iteration; `base >= 2` guarantees the value shrinks.
    while current_value % base == 0 {
        current_exponent += 1;
        current_value /= base;
    }

    // Anything left over other than 1 means `value` had a factor that is not `base`.
    if current_value == 1 {
        Some(current_exponent)
    } else {
        None
    }
}

/// Reverses the lowest `reversal_iters` base-3 digits of `value`.
///
/// This is the radix-3 counterpart of radix4's bit reversal. Radix4 divides by 4, multiplies by 4
/// and takes remainders modulo 4, all of which happen to be expressible as bit manipulation, so it
/// looks like a bit reversal. That is only a special case of a more general "remainder reversal":
/// repeatedly take the remainder of dividing by N and build a new number in which those
/// remainders appear in the opposite order. Here N is 3, so no actual bits are reversed, but the
/// conceptual result is the same.
///
/// Digits of `value` above the lowest `reversal_iters` ones are discarded.
pub fn reverse_bits(value: usize, reversal_iters: usize) -> usize {
    // State is (digits emitted so far, digits still to consume); each step moves the lowest
    // remaining digit onto the low end of the accumulator.
    let (reversed, _unconsumed) = (0..reversal_iters).fold(
        (0usize, value),
        |(accumulated, remaining), _| (accumulated * 3 + remaining % 3, remaining / 3),
    );
    reversed
}

unsafe fn butterfly_3<T: FftNum>(
data: &mut [Complex<T>],
twiddles: &[Complex<T>],
Expand Down
115 changes: 26 additions & 89 deletions src/algorithm/radix4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use num_complex::Complex;
use num_traits::Zero;

use crate::algorithm::butterflies::{Butterfly1, Butterfly16, Butterfly2, Butterfly4, Butterfly8};
use crate::array_utils;
use crate::array_utils::{self, bitreversed_transpose};
use crate::common::{fft_error_inplace, fft_error_outofplace};
use crate::{common::FftNum, twiddles, FftDirection};
use crate::{Direction, Fft, Length};
Expand Down Expand Up @@ -61,17 +61,19 @@ impl<T: FftNum> Radix4<T> {
// we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=4 and height=len/4
// but mixed radix only does one step and then calls itself recursively, and this algorithm does every layer all the way down
// so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
let mut twiddle_stride = len / (base_len * 4);
const ROW_COUNT: usize = 4;
let mut cross_fft_len = base_len * ROW_COUNT;
let mut twiddle_factors = Vec::with_capacity(len * 2);
while twiddle_stride > 0 {
let num_rows = len / (twiddle_stride * 4);
for i in 0..num_rows {
for k in 1..4 {
let twiddle = twiddles::compute_twiddle(i * k * twiddle_stride, len, direction);
while cross_fft_len <= len {
let num_columns = cross_fft_len / ROW_COUNT;

for i in 0..num_columns {
for k in 1..ROW_COUNT {
let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
twiddle_factors.push(twiddle);
}
}
twiddle_stride /= 4;
cross_fft_len *= ROW_COUNT;
}

Self {
Expand All @@ -87,115 +89,50 @@ impl<T: FftNum> Radix4<T> {

fn perform_fft_out_of_place(
&self,
signal: &[Complex<T>],
spectrum: &mut [Complex<T>],
input: &[Complex<T>],
output: &mut [Complex<T>],
_scratch: &mut [Complex<T>],
) {
// copy the data into the spectrum vector
// copy the data into the output vector
if self.len() == self.base_len {
spectrum.copy_from_slice(signal);
output.copy_from_slice(input);
} else {
bitreversed_transpose(self.base_len, signal, spectrum);
bitreversed_transpose::<Complex<T>, 4>(self.base_len, input, output);
}

// Base-level FFTs
self.base_fft.process_with_scratch(spectrum, &mut []);
self.base_fft.process_with_scratch(output, &mut []);

// cross-FFTs
let mut current_size = self.base_len * 4;
const ROW_COUNT: usize = 4;
let mut cross_fft_len = self.base_len * ROW_COUNT;
let mut layer_twiddles: &[Complex<T>] = &self.twiddles;

while current_size <= signal.len() {
let num_rows = signal.len() / current_size;
while cross_fft_len <= input.len() {
let num_rows = input.len() / cross_fft_len;
let num_columns = cross_fft_len / ROW_COUNT;

for i in 0..num_rows {
unsafe {
butterfly_4(
&mut spectrum[i * current_size..],
&mut output[i * cross_fft_len..],
layer_twiddles,
current_size / 4,
num_columns,
self.direction,
)
}
}

//skip past all the twiddle factors used in this layer
let twiddle_offset = (current_size * 3) / 4;
// skip past all the twiddle factors used in this layer
let twiddle_offset = num_columns * (ROW_COUNT - 1);
layer_twiddles = &layer_twiddles[twiddle_offset..];

current_size *= 4;
cross_fft_len *= ROW_COUNT;
}
}
}
boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len);

/// Preparing for radix 4 is similar to a transpose, where the column index is bit reversed
/// (in pairs of bits, i.e. base-4 digits).
///
/// `input` is read as a row-major matrix with `height` rows and `width = input.len() / height`
/// columns. The element at column `x`, row `y` is written to
/// `output[y + reverse_bits(x, rev_digits) * height]`, where `rev_digits` is the number of bit
/// pairs in a column index.
///
/// Unrolling the outer loop by a factor 4 (four columns per iteration) helps speed things up.
/// As a consequence nothing is written when `width` is 1; in that case the reordering is the
/// identity and the caller is expected to copy the data itself.
///
/// # Panics
///
/// Panics if `height` is zero, if `input.len()` is not a multiple of `height`, if `input` and
/// `output` have different lengths, or if `width` is non-zero and not a power of four.
pub fn bitreversed_transpose<T: Copy>(height: usize, input: &[T], output: &mut [T]) {
    // Let's make sure the arguments are ok before doing anything else:
    // the unchecked accesses below depend on every one of these conditions.
    assert!(height > 0 && input.len() % height == 0 && input.len() == output.len());

    let width = input.len() / height;
    if width == 0 {
        // Empty input: there is nothing to reorder.
        return;
    }

    // A width that is not a power of four would make several columns share one reversed index,
    // silently leaving part of the output unwritten. Fail loudly instead.
    assert!(
        width.is_power_of_two() && width.trailing_zeros() % 2 == 0,
        "bitreversed_transpose requires the width to be a power of four. Got {}",
        width
    );

    let quarter_width = width / 4;
    let rev_digits = (width.trailing_zeros() / 2) as usize;

    for x in 0..quarter_width {
        let x0 = 4 * x;
        let x1 = 4 * x + 1;
        let x2 = 4 * x + 2;
        let x3 = 4 * x + 3;

        let x_rev = [
            reverse_bits(x0, rev_digits),
            reverse_bits(x1, rev_digits),
            reverse_bits(x2, rev_digits),
            reverse_bits(x3, rev_digits),
        ];

        // Assert that the bit reversed indices will not exceed the length of the output.
        // The highest index the loop reaches is: (x_rev[n] + 1)*height - 1
        // The last element of the data is at index: width*height - 1
        // Thus it is sufficient to assert that x_rev[n]<width.
        assert!(x_rev[0] < width && x_rev[1] < width && x_rev[2] < width && x_rev[3] < width);

        for y in 0..height {
            let input_index0 = x0 + y * width;
            let input_index1 = x1 + y * width;
            let input_index2 = x2 + y * width;
            let input_index3 = x3 + y * width;
            let output_index0 = y + x_rev[0] * height;
            let output_index1 = y + x_rev[1] * height;
            let output_index2 = y + x_rev[2] * height;
            let output_index3 = y + x_rev[3] * height;

            // SAFETY: x0..=x3 are all < width and y < height, so every input index is at most
            // (width - 1) + (height - 1) * width = width * height - 1 < input.len().
            // Every x_rev[n] < width was asserted above, so every output index is at most
            // (height - 1) + (width - 1) * height = width * height - 1 < output.len(),
            // using input.len() == output.len() == width * height from the first assertion.
            unsafe {
                let temp0 = *input.get_unchecked(input_index0);
                let temp1 = *input.get_unchecked(input_index1);
                let temp2 = *input.get_unchecked(input_index2);
                let temp3 = *input.get_unchecked(input_index3);

                *output.get_unchecked_mut(output_index0) = temp0;
                *output.get_unchecked_mut(output_index1) = temp1;
                *output.get_unchecked_mut(output_index2) = temp2;
                *output.get_unchecked_mut(output_index3) = temp3;
            }
        }
    }
}

/// Reverses the order of the lowest `bitpairs` pairs of bits in `value`.
///
/// The two bits inside each pair keep their order; only the pairs themselves are swapped around.
/// For 8 bits (4 pairs): abcdefgh -> ghefcdab
///
/// Bits of `value` above the lowest `bitpairs` pairs are discarded.
pub fn reverse_bits(value: usize, bitpairs: usize) -> usize {
    // State is (pairs emitted so far, pairs still to consume). After the shift the two low bits
    // of the accumulator are zero, so OR-ing in the next pair is the same as adding it.
    let (reversed, _unconsumed) = (0..bitpairs).fold(
        (0usize, value),
        |(accumulated, remaining), _| ((accumulated << 2) | (remaining & 0b11), remaining >> 2),
    );
    reversed
}

unsafe fn butterfly_4<T: FftNum>(
data: &mut [Complex<T>],
twiddles: &[Complex<T>],
Expand Down
Loading
Loading