Skip to content

Commit

Permalink
Finish implementation of precision as a feature
Browse files Browse the repository at this point in the history
  • Loading branch information
smu160 committed Mar 4, 2024
1 parent bbac109 commit 5cd34b0
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 134 deletions.
27 changes: 25 additions & 2 deletions examples/fftwrb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@ use std::{env, ptr::slice_from_raw_parts_mut, str::FromStr};

use fftw::{
array::AlignedVec,
plan::{C2CPlan, C2CPlan64},
types::{Flag, Sign},
};
use fftw::plan::C2CPlan;
#[cfg(feature = "single")]
use fftw::plan::C2CPlan32;
#[cfg(feature = "double")]
use fftw::plan::C2CPlan64;
use utilities::{gen_random_signal, rustfft::num_complex::Complex};

use phastft::Float;

fn benchmark_fftw(n: usize) {
let big_n = 1 << n;

let mut reals = vec![0.0; big_n];
let mut imags = vec![0.0; big_n];

gen_random_signal(&mut reals, &mut imags);
gen_random_signal::<Float>(&mut reals, &mut imags);
let mut nums = AlignedVec::new(big_n);
reals
.drain(..)
Expand All @@ -22,6 +28,8 @@ fn benchmark_fftw(n: usize) {
.for_each(|((re, im), z)| *z = Complex::new(re, im));

let now = std::time::Instant::now();

#[cfg(feature = "double")]
C2CPlan64::aligned(
&[big_n],
Sign::Backward,
Expand All @@ -34,6 +42,21 @@ fn benchmark_fftw(n: usize) {
&mut nums,
)
.unwrap();

#[cfg(feature = "single")]
C2CPlan32::aligned(
&[big_n],
Sign::Backward,
Flag::DESTROYINPUT | Flag::ESTIMATE,
)
.unwrap()
.c2c(
// SAFETY: See above comment.
unsafe { &mut *slice_from_raw_parts_mut(nums.as_mut_ptr(), big_n) },
&mut nums,
)
.unwrap();

let elapsed = now.elapsed().as_micros();
println!("{elapsed}");
}
Expand Down
4 changes: 2 additions & 2 deletions src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub(crate) fn fft_chunk_n_simd(
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist);

reals_s0
.array_chunks_mut::<8>()
.chunks_exact_mut(8)
.zip(reals_s1.chunks_exact_mut(8))
.zip(imags_s0.chunks_exact_mut(8))
.zip(imags_s1.chunks_exact_mut(8))
Expand Down Expand Up @@ -59,7 +59,7 @@ pub(crate) fn fft_chunk_n_simd(
) {
const VECTOR_SIZE: usize = 16;
let chunk_size = dist << 1;
assert!(chunk_size >= 16);
assert!(chunk_size >= 32);

reals
.chunks_exact_mut(chunk_size)
Expand Down
23 changes: 4 additions & 19 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#![warn(clippy::perf)]
#![forbid(unsafe_code)]
#![feature(portable_simd)]
#![feature(array_chunks)]

use crate::cobra::cobra_apply;
use crate::kernels::{fft_chunk_2, fft_chunk_4, fft_chunk_n, fft_chunk_n_simd};
Expand All @@ -20,11 +19,12 @@ mod kernels;
pub mod options;
pub mod planner;
mod twiddles;
mod utils;

/// Redefine `Float` as `f64` for double precision data
#[cfg(feature = "double")]
pub type Float = f64;

/// Redefine `Float` as `f32` for single precision data
#[cfg(feature = "single")]
pub type Float = f32;

Expand Down Expand Up @@ -91,7 +91,7 @@ pub fn fft_with_opts_and_plan(
if t < n - 1 {
filter_twiddles(twiddles_re, twiddles_im);
}
if chunk_size >= 16 {
if chunk_size >= 32 {
fft_chunk_n_simd(reals, imags, twiddles_re, twiddles_im, dist);
} else {
fft_chunk_n(reals, imags, twiddles_re, twiddles_im, dist);
Expand All @@ -118,29 +118,14 @@ pub fn fft_with_opts_and_plan(
mod tests {
use std::ops::Range;

#[cfg(feature = "single")]
use utilities::assert_f32_closeness;
#[cfg(feature = "double")]
use utilities::assert_f64_closeness;
use utilities::assert_float_closeness;
use utilities::rustfft::FftPlanner;
use utilities::rustfft::num_complex::Complex;

use crate::planner::Direction;

Check warning on line 125 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Test

the item `Direction` is imported redundantly

Check warning on line 125 in src/lib.rs

View workflow job for this annotation

GitHub Actions / Test

the item `Direction` is imported redundantly

use super::*;

fn assert_float_closeness(actual: Float, expected: Float, epsilon: Float) {
#[cfg(feature = "double")]
{
assert_f64_closeness(actual, expected, epsilon)
}

#[cfg(feature = "single")]
{
assert_f32_closeness(actual, expected, epsilon)
}
}

#[should_panic]
#[test]
fn non_power_of_two_fft() {
Expand Down
22 changes: 3 additions & 19 deletions src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,10 @@ impl Planner {

#[cfg(test)]
mod tests {
#[cfg(feature = "single")]
use utilities::assert_f32_closeness;
#[cfg(feature = "double")]
use utilities::assert_f64_closeness;
use utilities::assert_float_closeness;

use crate::Float;
use crate::planner::{Direction, Planner};

fn assert_float_closeness(actual: Float, expected: Float, epsilon: Float) {
#[cfg(feature = "double")]
{
assert_f64_closeness(actual, expected, epsilon)
}

#[cfg(feature = "single")]
{
assert_f32_closeness(actual, expected, epsilon)
}
}

#[test]
fn no_twiddles() {
for num_points in [2, 4] {
Expand Down Expand Up @@ -117,8 +101,8 @@ mod tests {
.for_each(|(((a, b), c), d)| {
let temp_re = a * c - b * d;
let temp_im = a * d + b * c;
assert_float_closeness(temp_re, 1.0, 1e-6);
assert_float_closeness(temp_im, 0.0, 1e-6);
assert_float_closeness(temp_re, 1.0, 1e-3);
assert_float_closeness(temp_im, 0.0, 1e-3);
});
}
}
Expand Down
70 changes: 25 additions & 45 deletions src/twiddles.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#[cfg(feature = "single")]
use std::simd::f32x8;
#[cfg(feature = "double")]
use std::simd::f64x8;
use std::simd::Simd;

use crate::Float;
use crate::planner::Direction;

const PI: Float = std::f64::consts::PI as Float;
#[cfg(feature = "single")]
const PI: Float = std::f32::consts::PI;

#[cfg(feature = "double")]
const PI: Float = std::f64::consts::PI;

pub(crate) struct Twiddles {
st: Float,
Expand Down Expand Up @@ -109,21 +111,10 @@ pub(crate) fn generate_twiddles_simd(
};

let apply_symmetry_re = |input: &[Float], output: &mut [Float]| {
#[cfg(feature = "double")]
{
let first_re = f64x8::from_slice(input);
let minus_one = f64x8::splat(-1.0);
let negated = (first_re * minus_one).reverse();
output.copy_from_slice(negated.as_array());
}

#[cfg(feature = "single")]
{
let first_re = f32x8::from_slice(input);
let minus_one = f32x8::splat(-1.0);
let negated = (first_re * minus_one).reverse();
output.copy_from_slice(negated.as_array());
}
let first_re = Simd::<Float, 8>::from_slice(input);
let minus_one = Simd::<Float, 8>::splat(-1.0);
let negated = (first_re * minus_one).reverse();
output.copy_from_slice(negated.as_array());
};

let apply_symmetry_im = |input: &[Float], output: &mut [Float]| {
Expand Down Expand Up @@ -207,26 +198,15 @@ pub(crate) fn filter_twiddles(twiddles_re: &mut Vec<Float>, twiddles_im: &mut Ve

#[cfg(test)]
mod tests {
#[cfg(feature = "single")]
use utilities::assert_f32_closeness;
#[cfg(feature = "double")]
use utilities::assert_f64_closeness;
use utilities::assert_float_closeness;

use super::*;

const FRAC_1_SQRT_2: Float = std::f64::consts::FRAC_1_SQRT_2 as Float;

fn assert_float_closeness(actual: Float, expected: Float, epsilon: Float) {
#[cfg(feature = "double")]
{
assert_f64_closeness(actual, expected, epsilon)
}
#[cfg(feature = "double")]
const FRAC_1_SQRT_2: Float = std::f64::consts::FRAC_1_SQRT_2;

#[cfg(feature = "single")]
{
assert_f32_closeness(actual, expected, epsilon)
}
}
#[cfg(feature = "single")]
const FRAC_1_SQRT_2: Float = std::f32::consts::FRAC_1_SQRT_2;

#[test]
fn twiddles_4() {
Expand All @@ -235,23 +215,23 @@ mod tests {

let (w_re, w_im) = twiddle_iter.next().unwrap();
println!("{w_re} {w_im}");
assert_float_closeness(w_re, 1.0, 1e-10);
assert_float_closeness(w_im, 0.0, 1e-10);
assert_float_closeness(w_re, 1.0, 1e-6);
assert_float_closeness(w_im, 0.0, 1e-6);

let (w_re, w_im) = twiddle_iter.next().unwrap();
println!("{w_re} {w_im}");
assert_float_closeness(w_re, FRAC_1_SQRT_2, 1e-10);
assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-10);
assert_float_closeness(w_re, FRAC_1_SQRT_2, 1e-6);
assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-6);

let (w_re, w_im) = twiddle_iter.next().unwrap();
println!("{w_re} {w_im}");
assert_float_closeness(w_re, 0.0, 1e-10);
assert_float_closeness(w_im, -1.0, 1e-10);
assert_float_closeness(w_re, 0.0, 1e-6);
assert_float_closeness(w_im, -1.0, 1e-6);

let (w_re, w_im) = twiddle_iter.next().unwrap();
println!("{w_re} {w_im}");
assert_float_closeness(w_re, -FRAC_1_SQRT_2, 1e-10);
assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-10);
assert_float_closeness(w_re, -FRAC_1_SQRT_2, 1e-6);
assert_float_closeness(w_im, -FRAC_1_SQRT_2, 1e-6);
}

#[test]
Expand All @@ -266,14 +246,14 @@ mod tests {
.iter()
.zip(twiddles_re_ref.iter())
.for_each(|(simd, reference)| {
assert_float_closeness(*simd, *reference, 1e-10);
assert_float_closeness(*simd, *reference, 1e-3);
});

twiddles_im
.iter()
.zip(twiddles_im_ref.iter())
.for_each(|(simd, reference)| {
assert_float_closeness(*simd, *reference, 1e-10);
assert_float_closeness(*simd, *reference, 1e-3);
});
}
}
Expand Down
Loading

0 comments on commit 5cd34b0

Please sign in to comment.