diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml new file mode 100644 index 0000000..e559401 --- /dev/null +++ b/.github/workflows/run_tests.yml @@ -0,0 +1,55 @@ +on: [pull_request] + +name: CI + +jobs: + check: + name: Check+Test + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - stable + - beta + - nightly + - 1.37 + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Install toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + + - name: Run cargo check + uses: actions-rs/cargo@v1 + with: + command: check + + - name: Run cargo test + uses: actions-rs/cargo@v1 + with: + command: test + fmt: + name: Rustfmt + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2 + + - name: Install toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + components: rustfmt + + - name: Run cargo fmt + uses: actions-rs/cargo@v1 + with: + command: fmt + args: -- --check diff --git a/Cargo.toml b/Cargo.toml index 5b335c1..a387aaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "realfft" -version = "0.4.0" +version = "1.0.0" authors = ["HEnquist "] edition = "2018" description = "Real-to-complex FFT and complex-to-real iFFT for Rust" @@ -17,6 +17,7 @@ rustfft = "5.0.0" [dev-dependencies] criterion = "0.3" +rand = "0.8.1" [[bench]] name = "realfft" diff --git a/README.md b/README.md index 0fe7245..1409f7d 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,13 @@ -# RealFFT: Real-to-complex FFT and complex-to-real iFFT based on RustFFT +# realfft -This library is a wrapper for RustFFT that enables faster computations when the input data is real. -It packs a 2N long real vector into an N long complex vector, which is transformed using a standard FFT. +## RealFFT: Real-to-complex FFT and complex-to-real iFFT based on RustFFT + +This library is a wrapper for RustFFT that enables performing FFT of real-valued data. +The API is designed to be as similar as possible to RustFFT. + +Using this library instead of RustFFT directly avoids the need of converting real-valued data to complex before performing a FFT. +If the length is even, it also enables faster computations by using a complex FFT of half the length. +It then packs a 2N long real vector into an N long complex vector, which is transformed using a standard FFT. It then post-processes the result to give only the first half of the complex spectrum, as an N+1 long complex vector. The iFFT goes through the same steps backwards, to transform an N+1 long complex spectrum to a 2N long real result. @@ -9,86 +15,114 @@ The iFFT goes through the same steps backwards, to transform an N+1 long complex The speed increase compared to just converting the input to a 2N long complex vector and using a 2N long FFT depends on the length f the input data. The largest improvements are for long FFTs and for lengths over around 1000 elements there is an improvement of about a factor 2. -The difference shrinks for shorter lengths, and around 100 elements there is no longer any difference. +The difference shrinks for shorter lengths, and around 30 elements there is no longer any difference. -## Why use real-to-complex fft? -### Using a complex-to-complex fft -A simple way to get the fft of a rea values vector is to convert it to complex, and using a complex-to-complex fft. +### Why use real-to-complex FFT? +#### Using a complex-to-complex FFT +A simple way to get the FFT of a rea values vector is to convert it to complex, and using a complex-to-complex FFT. -Let's assume `x` is a 6 element long real vector: -```text +Let's assume `x` is a 6 element long real vector: +``` x = [x0r, x1r, x2r, x3r, x4r, x5r] ``` -Converted to complex, using the notation `(xNr, xNi)` for the complex value `xN`, this becomes: -```text +We now convert `x` to complex by adding an imaginary part with value zero. Using the notation `(xNr, xNi)` for the complex value `xN`, this becomes: +``` x_c = [(x0r, 0), (x1r, 0), (x2r, 0), (x3r, 0), (x4r, 0, (x5r, 0)] ``` - -The general result of `X = FFT(x)` is: -```text -X = [(X0r, X0i), (X1r, X1i), (X2r, X2i), (X3r, X3i), (X4r, X4i), (X5r, X5i)] +Performing a normal complex FFT, the result of `FFT(x_c)` is: +``` +FFT(x_c) = [(X0r, X0i), (X1r, X1i), (X2r, X2i), (X3r, X3i), (X4r, X4i), (X5r, X5i)] ``` -However, because our `x` was real-valued, some of this is redundant: -```text +But because our `x_c` is real-valued (all imaginary parts are zero), some of this becomes redundant: +``` FFT(x_c) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, 0), (X2r, -X2i), (X1r, -X1i)] ``` -As we can see, the output contains a fair bit of redundant data. But it still takes time for the FFT to calculate these values. Converting the input data to complex also takes a little bit of time. +The last two values are the complex conjugates of `X1` and `X2`. Additionally, `X0i` and `X3i` are zero. +As we can see, the output contains 6 independent values, and the rest is redundant. +But it still takes time for the FFT to calculate the redundant values. +Converting the input data to complex also takes a little bit of time. + +If the length of `x` instead had been 7, result would have been: +``` +FFT(x_c) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, X3i), (X3r, -X3i), (X2r, -X2i), (X1r, -X1i)] +``` + +The result is similar, but this time there is no zero at `X3i`. Also in this case we got the same number of indendent values as we started with. -### real-to-complex -Using a real-to-complex fft removes the need for converting the input data to complex. +#### Real-to-complex +Using a real-to-complex FFT removes the need for converting the input data to complex. It also avoids caclulating the redundant output values. -The result is: -```text +The result for 6 elements is: +``` RealFFT(x) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, 0)] ``` -This is the data layout output by the real-to-complex fft, and the one expected as input to the complex-to-real ifft. +The result for 7 elements is: +``` +RealFFT(x) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, X3i)] +``` + +This is the data layout output by the real-to-complex FFT, and the one expected as input to the complex-to-real iFFT. -## Scaling +### Scaling RealFFT matches the behaviour of RustFFT and does not normalize the output of either FFT of iFFT. To get normalized results, each element must be scaled by `1/sqrt(length)`. If the processing involves both an FFT and an iFFT step, it is advisable to merge the two normalization steps to a single, by scaling by `1/length`. -## Documentation +### Documentation The full documentation can be generated by rustdoc. To generate and view it run: -```text +``` cargo doc --open ``` -## Benchmarks +### Benchmarks To run a set of benchmarks comparing real-to-complex FFT with standard complex-to-complex, type: -```text +``` cargo bench ``` The results are printed while running, and are compiled into an html report containing much more details. To view, open `target/criterion/report/index.html` in a browser. -## Example +### Example Transform a vector, and then inverse transform the result. ```rust -use realfft::{ComplexToReal, RealToComplex}; +use realfft::RealFftPlanner; use rustfft::num_complex::Complex; use rustfft::num_traits::Zero; -// make dummy input vector, spectrum and output vectors -let mut indata = vec![0.0f64; 256]; -let mut spectrum: Vec> = vec![Complex::zero(); 129]; -let mut outdata: Vec = vec![0.0; 256]; +let length = 256; -//create an FFT and forward transform the input data -let mut r2c = RealToComplex::::new(256).unwrap(); +// make a planner +let mut real_planner = RealFftPlanner::::new(); + +// create a FFT +let r2c = real_planner.plan_fft_forward(length); +// make input and output vectors +let mut indata = r2c.make_input_vec(); +let mut spectrum = r2c.make_output_vec(); + +// Are they the length we expect? +assert_eq!(indata.len(), length); +assert_eq!(spectrum.len(), length/2+1); + +// Forward transform the input data r2c.process(&mut indata, &mut spectrum).unwrap(); -// create an iFFT and inverse transform the spectum -let mut c2r = ComplexToReal::::new(256).unwrap(); -c2r.process(&spectrum, &mut outdata).unwrap(); +// create an iFFT and an output vector +let c2r = real_planner.plan_fft_inverse(length); +let mut outdata = c2r.make_output_vec(); +assert_eq!(outdata.len(), length); + +c2r.process(&mut spectrum, &mut outdata).unwrap(); ``` -## Compatibility +### Compatibility + +The `realfft` crate requires rustc version 1.37 or newer. -The `realfft` crate requires rustc version 1.37 or newer. \ No newline at end of file +License: MIT diff --git a/benches/realfft.rs b/benches/realfft.rs index b46e1a9..83b1c43 100644 --- a/benches/realfft.rs +++ b/benches/realfft.rs @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Bencher, BenchmarkId, Criterion extern crate realfft; extern crate rustfft; -use realfft::RealToComplex; +use realfft::RealFftPlanner; use rustfft::num_complex::Complex; /// Times just the FFT execution (not allocation and pre-calculation) @@ -24,7 +24,8 @@ fn bench_fft(b: &mut Bencher, len: usize) { } fn bench_realfft(b: &mut Bencher, len: usize) { - let mut fft = RealToComplex::::new(len).unwrap(); + let mut planner = RealFftPlanner::::new(); + let fft = planner.plan_fft_forward(len); let mut signal = vec![0_f64; len]; let mut spectrum = vec![ @@ -34,27 +35,97 @@ fn bench_realfft(b: &mut Bencher, len: usize) { }; len / 2 + 1 ]; - b.iter(|| fft.process(&mut signal, &mut spectrum)); + let mut scratch = vec![Complex::from(0.0); fft.get_scratch_len()]; + b.iter(|| fft.process_with_scratch(&mut signal, &mut spectrum, &mut scratch)); +} + +/// Times just the FFT execution (not allocation and pre-calculation) +/// for a given length +fn bench_ifft(b: &mut Bencher, len: usize) { + let mut planner = rustfft::FftPlanner::new(); + let fft = planner.plan_fft_inverse(len); + let mut scratch = vec![Complex::from(0.0); fft.get_outofplace_scratch_len()]; + + let mut signal = vec![ + Complex { + re: 0_f64, + im: 0_f64 + }; + len + ]; + let mut spectrum = signal.clone(); + b.iter(|| fft.process_outofplace_with_scratch(&mut signal, &mut spectrum, &mut scratch)); } -fn bench_pow2(c: &mut Criterion) { - let mut group = c.benchmark_group("Powers of 2"); - for i in [64, 128, 256, 512, 4096, 65536].iter() { +fn bench_realifft(b: &mut Bencher, len: usize) { + let mut planner = RealFftPlanner::::new(); + let fft = planner.plan_fft_inverse(len); + + let mut signal = vec![0_f64; len]; + let mut spectrum = vec![ + Complex { + re: 0_f64, + im: 0_f64 + }; + len / 2 + 1 + ]; + let mut scratch = vec![Complex::from(0.0); fft.get_scratch_len()]; + b.iter(|| fft.process_with_scratch(&mut spectrum, &mut signal, &mut scratch)); +} + +fn bench_pow2_fw(c: &mut Criterion) { + let mut group = c.benchmark_group("Fw Powers of 2"); + for i in [8, 16, 32, 64, 128, 256, 1024, 4096, 65536].iter() { group.bench_with_input(BenchmarkId::new("Complex", i), i, |b, i| bench_fft(b, *i)); group.bench_with_input(BenchmarkId::new("Real", i), i, |b, i| bench_realfft(b, *i)); } group.finish(); } -fn bench_pow7(c: &mut Criterion) { - let mut group = c.benchmark_group("Powers of 7"); - for i in [2 * 343, 2 * 2401, 2 * 16807].iter() { - group.bench_with_input(BenchmarkId::new("Complex", i), i, |b, i| bench_fft(b, *i)); - group.bench_with_input(BenchmarkId::new("Real", i), i, |b, i| bench_realfft(b, *i)); +fn bench_pow2_inv(c: &mut Criterion) { + let mut group = c.benchmark_group("Inv Powers of 2"); + for i in [8, 16, 32, 64, 128, 256, 1024, 4096, 65536].iter() { + group.bench_with_input(BenchmarkId::new("Complex", i), i, |b, i| bench_ifft(b, *i)); + group.bench_with_input(BenchmarkId::new("Real", i), i, |b, i| bench_realifft(b, *i)); + } + group.finish(); +} + +//fn bench_pow7(c: &mut Criterion) { +// let mut group = c.benchmark_group("Powers of 7"); +// for i in [2 * 343, 2 * 2401, 2 * 16807].iter() { +// group.bench_with_input(BenchmarkId::new("Complex", i), i, |b, i| bench_fft(b, *i)); +// group.bench_with_input(BenchmarkId::new("Real", i), i, |b, i| bench_realfft(b, *i)); +// } +// group.finish(); +//} + +fn bench_range_fw(c: &mut Criterion) { + let mut group = c.benchmark_group("Fw Range 1022-1025"); + for i in 1022..1026 { + group.bench_with_input(BenchmarkId::new("Complex", i), &i, |b, i| bench_fft(b, *i)); + group.bench_with_input(BenchmarkId::new("Real", i), &i, |b, i| bench_realfft(b, *i)); + } + group.finish(); +} + +fn bench_range_inv(c: &mut Criterion) { + let mut group = c.benchmark_group("Inv Range 1022-1025"); + for i in 1022..1026 { + group.bench_with_input(BenchmarkId::new("Complex", i), &i, |b, i| bench_ifft(b, *i)); + group.bench_with_input(BenchmarkId::new("Real", i), &i, |b, i| { + bench_realifft(b, *i) + }); } group.finish(); } -criterion_group!(benches, bench_pow2, bench_pow7); +criterion_group!( + benches, + bench_pow2_fw, + bench_range_fw, + bench_pow2_inv, + bench_range_inv +); criterion_main!(benches); diff --git a/src/lib.rs b/src/lib.rs index 9f2009e..6f64c2d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,11 @@ //! # RealFFT: Real-to-complex FFT and complex-to-real iFFT based on RustFFT //! -//! This library is a wrapper for RustFFT that enables faster computations when the input data is real. -//! It packs a 2N long real vector into an N long complex vector, which is transformed using a standard FFT. +//! This library is a wrapper for RustFFT that enables performing FFT of real-valued data. +//! The API is designed to be as similar as possible to RustFFT. +//! +//! Using this library instead of RustFFT directly avoids the need of converting real-valued data to complex before performing a FFT. +//! If the length is even, it also enables faster computations by using a complex FFT of half the length. +//! It then packs a 2N long real vector into an N long complex vector, which is transformed using a standard FFT. //! It then post-processes the result to give only the first half of the complex spectrum, as an N+1 long complex vector. //! //! The iFFT goes through the same steps backwards, to transform an N+1 long complex spectrum to a 2N long real result. @@ -9,45 +13,59 @@ //! The speed increase compared to just converting the input to a 2N long complex vector //! and using a 2N long FFT depends on the length f the input data. //! The largest improvements are for long FFTs and for lengths over around 1000 elements there is an improvement of about a factor 2. -//! The difference shrinks for shorter lengths, and around 100 elements there is no longer any difference. +//! The difference shrinks for shorter lengths, and around 30 elements there is no longer any difference. //! -//! ## Why use real-to-complex fft? -//! ### Using a complex-to-complex fft -//! A simple way to get the fft of a rea values vector is to convert it to complex, and using a complex-to-complex fft. +//! ## Why use real-to-complex FFT? +//! ### Using a complex-to-complex FFT +//! A simple way to get the FFT of a rea values vector is to convert it to complex, and using a complex-to-complex FFT. //! //! Let's assume `x` is a 6 element long real vector: //! ```text //! x = [x0r, x1r, x2r, x3r, x4r, x5r] //! ``` //! -//! Converted to complex, using the notation `(xNr, xNi)` for the complex value `xN`, this becomes: +//! We now convert `x` to complex by adding an imaginary part with value zero. Using the notation `(xNr, xNi)` for the complex value `xN`, this becomes: //! ```text //! x_c = [(x0r, 0), (x1r, 0), (x2r, 0), (x3r, 0), (x4r, 0, (x5r, 0)] //! ``` //! +//! Performing a normal complex FFT, the result of `FFT(x_c)` is: +//! ```text +//! FFT(x_c) = [(X0r, X0i), (X1r, X1i), (X2r, X2i), (X3r, X3i), (X4r, X4i), (X5r, X5i)] +//! ``` //! -//! The general result of `X = FFT(x)` is: +//! But because our `x_c` is real-valued (all imaginary parts are zero), some of this becomes redundant: //! ```text -//! X = [(X0r, X0i), (X1r, X1i), (X2r, X2i), (X3r, X3i), (X4r, X4i), (X5r, X5i)] +//! FFT(x_c) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, 0), (X2r, -X2i), (X1r, -X1i)] //! ``` //! -//! However, because our `x` was real-valued, some of this is redundant: +//! The last two values are the complex conjugates of `X1` and `X2`. Additionally, `X0i` and `X3i` are zero. +//! As we can see, the output contains 6 independent values, and the rest is redundant. +//! But it still takes time for the FFT to calculate the redundant values. +//! Converting the input data to complex also takes a little bit of time. +//! +//! If the length of `x` instead had been 7, result would have been: //! ```text -//! FFT(x) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, 0), (X2r, -X2i), (X1r, -X1i)] +//! FFT(x_c) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, X3i), (X3r, -X3i), (X2r, -X2i), (X1r, -X1i)] //! ``` //! -//! As we can see, the output contains a fair bit of redundant data. But it still takes time for the FFT to calculate these values. Converting the input data to complex also takes a little bit of time. +//! The result is similar, but this time there is no zero at `X3i`. Also in this case we got the same number of indendent values as we started with. //! -//! ### real-to-complex -//! Using a real-to-complex fft removes the need for converting the input data to complex. +//! ### Real-to-complex +//! Using a real-to-complex FFT removes the need for converting the input data to complex. //! It also avoids caclulating the redundant output values. //! -//! The result is: +//! The result for 6 elements is: //! ```text //! RealFFT(x) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, 0)] //! ``` //! -//! This is the data layout output by the real-to-complex fft, and the one expected as input to the complex-to-real ifft. +//! The result for 7 elements is: +//! ```text +//! RealFFT(x) = [(X0r, 0), (X1r, X1i), (X2r, X2i), (X3r, X3i)] +//! ``` +//! +//! This is the data layout output by the real-to-complex FFT, and the one expected as input to the complex-to-real iFFT. //! //! ## Scaling //! RealFFT matches the behaviour of RustFFT and does not normalize the output of either FFT of iFFT. To get normalized results, each element must be scaled by `1/sqrt(length)`. If the processing involves both an FFT and an iFFT step, it is advisable to merge the two normalization steps to a single, by scaling by `1/length`. @@ -71,33 +89,51 @@ //! ## Example //! Transform a vector, and then inverse transform the result. //! ``` -//! use realfft::{ComplexToReal, RealToComplex}; +//! use realfft::RealFftPlanner; //! use rustfft::num_complex::Complex; //! use rustfft::num_traits::Zero; //! -//! // make dummy input vector, spectrum and output vectors -//! let mut indata = vec![0.0f64; 256]; -//! let mut spectrum: Vec> = vec![Complex::zero(); 129]; -//! let mut outdata: Vec = vec![0.0; 256]; +//! let length = 256; +//! +//! // make a planner +//! let mut real_planner = RealFftPlanner::::new(); +//! +//! // create a FFT +//! let r2c = real_planner.plan_fft_forward(length); +//! // make input and output vectors +//! let mut indata = r2c.make_input_vec(); +//! let mut spectrum = r2c.make_output_vec(); //! -//! //create an FFT and forward transform the input data -//! let mut r2c = RealToComplex::::new(256).unwrap(); +//! // Are they the length we expect? +//! assert_eq!(indata.len(), length); +//! assert_eq!(spectrum.len(), length/2+1); +//! +//! // Forward transform the input data //! r2c.process(&mut indata, &mut spectrum).unwrap(); //! -//! // create an iFFT and inverse transform the spectum -//! let mut c2r = ComplexToReal::::new(256).unwrap(); -//! c2r.process(&spectrum, &mut outdata).unwrap(); +//! // create an iFFT and an output vector +//! let c2r = real_planner.plan_fft_inverse(length); +//! let mut outdata = c2r.make_output_vec(); +//! assert_eq!(outdata.len(), length); +//! +//! c2r.process(&mut spectrum, &mut outdata).unwrap(); //! ``` //! //! ## Compatibility //! //! The `realfft` crate requires rustc version 1.37 or newer. +pub use rustfft::num_complex; +pub use rustfft::num_traits; +pub use rustfft::FftNum; + use rustfft::num_complex::Complex; use rustfft::num_traits::Zero; use rustfft::FftPlanner; +use std::collections::HashMap; use std::error; use std::fmt; +use std::sync::Arc; type Res = Result>; @@ -127,335 +163,851 @@ impl FftError { } } -/// An FFT that takes a real-valued input vector of length 2*N and transforms it to a complex -/// spectrum of length N+1. -pub struct RealToComplex { - sin_cos: Vec<(T, T)>, +fn compute_twiddle(index: usize, fft_len: usize) -> Complex { + let constant = -2f64 * std::f64::consts::PI / fft_len as f64; + let angle = constant * index as f64; + Complex { + re: T::from_f64(angle.cos()).unwrap(), + im: T::from_f64(angle.sin()).unwrap(), + } +} + +pub struct RealToComplexOdd { length: usize, fft: std::sync::Arc>, - buffer_out: Vec>, - scratch: Vec>, + scratch_len: usize, } -/// An FFT that takes a real-valued input vector of length 2*N and transforms it to a complex -/// spectrum of length N+1. -pub struct ComplexToReal { - sin_cos: Vec<(T, T)>, +pub struct RealToComplexEven { + twiddles: Vec>, + length: usize, + fft: std::sync::Arc>, + scratch_len: usize, +} + +pub struct ComplexToRealOdd { length: usize, fft: std::sync::Arc>, - buffer_in: Vec>, - scratch: Vec>, + scratch_len: usize, } -fn zip4( - a: A, - b: B, - c: C, - d: D, -) -> impl Iterator +pub struct ComplexToRealEven { + twiddles: Vec>, + length: usize, + fft: std::sync::Arc>, + scratch_len: usize, +} + +/// An FFT that takes a real-valued input vector of length 2*N and transforms it to a complex +/// spectrum of length N+1. +#[allow(clippy::len_without_is_empty)] +pub trait RealToComplex { + /// Transform a vector of N real-valued samples, storing the result in the N/2+1 (with N/2 rounded down) element long complex output vector. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also allocates additional scratch space as needed. + /// An error is returned if any of the given slices has the wrong length. + fn process(&self, input: &mut [T], output: &mut [Complex]) -> Res<()>; + + /// Transform a vector of N real-valued samples, storing the result in the N/2+1 (with N/2 rounded down) element long complex output vector. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once. + /// An error is returned if any of the given slices has the wrong length. + fn process_with_scratch( + &self, + input: &mut [T], + output: &mut [Complex], + scratch: &mut [Complex], + ) -> Res<()>; + + /// Get the length of the scratch space needed for `process_with_scratch`. + fn get_scratch_len(&self) -> usize; + + /// Get the number of points that this FFT can process. + fn len(&self) -> usize; + + /// Convenience method to make an input vector of the right type and length. + fn make_input_vec(&self) -> Vec; + + /// Convenience method to make an output vector of the right type and length. + fn make_output_vec(&self) -> Vec>; + + /// Convenience method to make a scratch vector of the right type and length. + fn make_scratch_vec(&self) -> Vec>; +} + +/// An FFT that takes a complex-valued input vector of length N+1 and transforms it to a complex +/// spectrum of length 2*N. +#[allow(clippy::len_without_is_empty)] +pub trait ComplexToReal { + /// Transform a complex spectrum of N/2+1 (with N/2 rounded down) values and store the real result in the N long output. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also allocates additional scratch space as needed. + /// An error is returned if any of the given slices has the wrong length. + fn process(&self, input: &mut [Complex], output: &mut [T]) -> Res<()>; + + /// Transform a complex spectrum of N/2+1 (with N/2 rounded down) values and store the real result in the 2*N long output. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once. + /// An error is returned if any of the given slices has the wrong length. + fn process_with_scratch( + &self, + input: &mut [Complex], + output: &mut [T], + scratch: &mut [Complex], + ) -> Res<()>; + + /// Get the length of the scratch space needed for `process_with_scratch`. + fn get_scratch_len(&self) -> usize; + + /// Get the number of points that this FFT can process. + fn len(&self) -> usize; + + /// Convenience method to make an input vector of the right type and length. + fn make_input_vec(&self) -> Vec>; + + /// Convenience method to make an output vector of the right type and length. + fn make_output_vec(&self) -> Vec; + + /// Convenience method to make a scratch vector of the right type and length. + fn make_scratch_vec(&self) -> Vec>; +} + +fn zip3(a: A, b: B, c: C) -> impl Iterator where A: IntoIterator, B: IntoIterator, C: IntoIterator, - D: IntoIterator, { a.into_iter() - .zip(b.into_iter().zip(c.into_iter().zip(d))) - .map(|(w, (x, (y, z)))| (w, x, y, z)) + .zip(b.into_iter().zip(c)) + .map(|(x, (y, z))| (x, y, z)) } -macro_rules! impl_r2c { - ($ft:ty) => { - impl RealToComplex<$ft> { - /// Create a new RealToComplex FFT for input data of a given length. Returns an error if the length is not even. - pub fn new(length: usize) -> Res { - if length % 2 > 0 { - return Err(Box::new(FftError::new("Length must be even"))); - } - let buffer_out = vec![Complex::zero(); length / 2 + 1]; - let mut sin_cos = Vec::with_capacity(length / 2); - let pi = std::f64::consts::PI as $ft; - for k in 0..length / 2 { - let sin = (k as $ft * pi / (length / 2) as $ft).sin(); - let cos = (k as $ft * pi / (length / 2) as $ft).cos(); - sin_cos.push((sin, cos)); - } - let mut fft_planner = FftPlanner::<$ft>::new(); - let fft = fft_planner.plan_fft_forward(length / 2); - let scratch = vec![Complex::zero(); fft.get_outofplace_scratch_len()]; - Ok(RealToComplex { - sin_cos, - length, - fft, - buffer_out, - scratch, - }) - } +/// A planner is used to create FFTs. It caches results internally, +/// so when making more than one FFT it is advisable to reuse the same planner. +pub struct RealFftPlanner { + planner: FftPlanner, + r2c_cache: HashMap>>, + c2r_cache: HashMap>>, +} - /// Transform a vector of 2*N real-valued samples, storing the result in the N+1 element long complex output vector. - /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. - pub fn process(&mut self, input: &mut [$ft], output: &mut [Complex<$ft>]) -> Res<()> { - if input.len() != self.length { - return Err(Box::new(FftError::new( - format!( - "Wrong length of input, expected {}, got {}", - self.length, - input.len() - ) - .as_str(), - ))); - } - if output.len() != (self.length / 2 + 1) { - return Err(Box::new(FftError::new( - format!( - "Wrong length of output, expected {}, got {}", - self.length / 2 + 1, - input.len() - ) - .as_str(), - ))); - } - let fftlen = self.length / 2; - //for (val, buf) in input.chunks(2).take(fftlen).zip(self.buffer_in.iter_mut()) { - // *buf = Complex::new(val[0], val[1]); - //} - let mut buf_in = unsafe { - let ptr = input.as_mut_ptr() as *mut Complex<$ft>; - let len = input.len(); - std::slice::from_raw_parts_mut(ptr, len / 2) - }; +impl RealFftPlanner { + /// Create a new planner. + pub fn new() -> Self { + let planner = FftPlanner::::new(); + Self { + r2c_cache: HashMap::new(), + c2r_cache: HashMap::new(), + planner, + } + } - // FFT and store result in buffer_out - self.fft.process_outofplace_with_scratch( - &mut buf_in, - &mut self.buffer_out[0..fftlen], - &mut self.scratch, - ); - - self.buffer_out[fftlen] = self.buffer_out[0]; - - for (&buf, &buf_rev, &(sin, cos), out) in zip4( - &self.buffer_out, - self.buffer_out.iter().rev(), - &self.sin_cos, - &mut output[..], - ) { - let xr = 0.5 - * ((buf.re + buf_rev.re) + cos * (buf.im + buf_rev.im) - - sin * (buf.re - buf_rev.re)); - let xi = 0.5 - * ((buf.im - buf_rev.im) - - sin * (buf.im + buf_rev.im) - - cos * (buf.re - buf_rev.re)); - *out = Complex::new(xr, xi); - } - output[fftlen] = Complex::new(self.buffer_out[0].re - self.buffer_out[0].im, 0.0); - Ok(()) - } + /// Plan a Real-to-Complex forward FFT. Returns the FFT in a shared reference. + /// If requesting a second FFT of the same length, this will return a new reference to the already existing one. + pub fn plan_fft_forward(&mut self, len: usize) -> Arc> { + if let Some(fft) = self.r2c_cache.get(&len) { + Arc::clone(&fft) + } else { + let fft = if len % 2 > 0 { + Arc::new(RealToComplexOdd::new(len, &mut self.planner)) as Arc> + } else { + Arc::new(RealToComplexEven::new(len, &mut self.planner)) + as Arc> + }; + self.r2c_cache.insert(len, Arc::clone(&fft)); + fft } - }; + } + + /// Plan a Complex-to-Real inverse FFT. Returns the FFT in a shared reference. + /// If requesting a second FFT of the same length, this will return a new reference to the already existing one. + pub fn plan_fft_inverse(&mut self, len: usize) -> Arc> { + if let Some(fft) = self.c2r_cache.get(&len) { + Arc::clone(&fft) + } else { + let fft = if len % 2 > 0 { + Arc::new(ComplexToRealOdd::new(len, &mut self.planner)) as Arc> + } else { + Arc::new(ComplexToRealEven::new(len, &mut self.planner)) + as Arc> + }; + self.c2r_cache.insert(len, Arc::clone(&fft)); + fft + } + } } -impl_r2c!(f64); -impl_r2c!(f32); - -macro_rules! impl_c2r { - ($ft:ty) => { - /// Create a new ComplexToReal iFFT for output data of a given length. Returns an error if the length is not even. - impl ComplexToReal<$ft> { - pub fn new(length: usize) -> Res { - if length % 2 > 0 { - return Err(Box::new(FftError::new("Length must be even"))); - } - let buffer_in = vec![Complex::zero(); length / 2]; - let mut sin_cos = Vec::with_capacity(length / 2); - let pi = std::f64::consts::PI as $ft; - for k in 0..length / 2 { - let sin = (k as $ft * pi / (length / 2) as $ft).sin(); - let cos = (k as $ft * pi / (length / 2) as $ft).cos(); - sin_cos.push((sin, cos)); - } - let mut fft_planner = FftPlanner::<$ft>::new(); - let fft = fft_planner.plan_fft_inverse(length / 2); - let scratch = vec![Complex::zero(); fft.get_outofplace_scratch_len()]; - Ok(ComplexToReal { - sin_cos, - length, - fft, - buffer_in, - scratch, - }) - } - /// Transform a complex spectrum of N+1 values and store the real result in the 2*N long output. - pub fn process(&mut self, input: &[Complex<$ft>], output: &mut [$ft]) -> Res<()> { - if input.len() != (self.length / 2 + 1) { - return Err(Box::new(FftError::new( - format!( - "Wrong length of input, expected {}, got {}", - self.length / 2 + 1, - input.len() - ) - .as_str(), - ))); - } - if output.len() != self.length { - return Err(Box::new(FftError::new( - format!( - "Wrong length of output, expected {}, got {}", - self.length, - input.len() - ) - .as_str(), - ))); - } - - for (&buf, &buf_rev, &(sin, cos), fft_input) in zip4( - input, - input.iter().rev(), - &self.sin_cos, - &mut self.buffer_in[..], - ) { - let xr = (buf.re + buf_rev.re) - - cos * (buf.im + buf_rev.im) - - sin * (buf.re - buf_rev.re); - let xi = (buf.im - buf_rev.im) + cos * (buf.re - buf_rev.re) - - sin * (buf.im + buf_rev.im); - *fft_input = Complex::new(xr, xi); - } - - // FFT and store result in buffer_out - let mut buf_out = unsafe { - let ptr = output.as_mut_ptr() as *mut Complex<$ft>; - let len = output.len(); - std::slice::from_raw_parts_mut(ptr, len / 2) +impl Default for RealFftPlanner { + fn default() -> Self { + Self::new() + } +} + +impl RealToComplexOdd { + /// Create a new RealToComplex FFT for input data of a given length, and uses the given FftPlanner to build the inner FFT. + /// Panics if the length is not odd. + pub fn new(length: usize, fft_planner: &mut FftPlanner) -> Self { + if length % 2 == 0 { + panic!("Length must be odd, got {}", length,); + } + let fft = fft_planner.plan_fft_forward(length); + let scratch_len = fft.get_inplace_scratch_len() + length; + RealToComplexOdd { + length, + fft, + scratch_len, + } + } +} + +impl RealToComplex for RealToComplexOdd { + /// Transform a vector of N real-valued samples, storing the result in the N/2+1 (with N/2 rounded down) element long complex output vector. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also allocates additional scratch space as needed. + /// An error is returned if any of the given slices has the wrong length. + fn process(&self, input: &mut [T], output: &mut [Complex]) -> Res<()> { + let mut scratch = self.make_scratch_vec(); + self.process_with_scratch(input, output, &mut scratch) + } + + /// Transform a vector of N real-valued samples, storing the result in the N/2+1 (with N/2 rounded down) element long complex output vector. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once. + /// An error is returned if any of the given slices has the wrong length. + fn process_with_scratch( + &self, + input: &mut [T], + output: &mut [Complex], + scratch: &mut [Complex], + ) -> Res<()> { + if input.len() != self.length { + return Err(Box::new(FftError::new( + format!( + "Wrong length of input, expected {}, got {}", + self.length, + input.len() + ) + .as_str(), + ))); + } + if output.len() != (self.length / 2 + 1) { + return Err(Box::new(FftError::new( + format!( + "Wrong length of output, expected {}, got {}", + self.length / 2 + 1, + input.len() + ) + .as_str(), + ))); + } + if scratch.len() != (self.scratch_len) { + return Err(Box::new(FftError::new( + format!( + "Wrong length of scratch, expected {}, got {}", + self.scratch_len / 2 + 1, + scratch.len() + ) + .as_str(), + ))); + } + let (buffer, fft_scratch) = scratch.split_at_mut(self.length); + + for (val, buf) in input.iter().zip(buffer.iter_mut()) { + *buf = Complex::new(*val, T::zero()); + } + // FFT and store result in buffer_out + self.fft.process_with_scratch(buffer, fft_scratch); + output.copy_from_slice(&buffer[0..self.length / 2 + 1]); + Ok(()) + } + + fn get_scratch_len(&self) -> usize { + self.scratch_len + } + + fn len(&self) -> usize { + self.length + } + + fn make_input_vec(&self) -> Vec { + vec![T::zero(); self.len()] + } + + fn make_output_vec(&self) -> Vec> { + vec![Complex::zero(); self.len() / 2 + 1] + } + + fn make_scratch_vec(&self) -> Vec> { + vec![Complex::zero(); self.get_scratch_len()] + } +} + +impl RealToComplexEven { + /// Create a new RealToComplex FFT for input data of a given length, and uses the given FftPlanner to build the inner FFT. + /// Panics if the length is not even. + pub fn new(length: usize, fft_planner: &mut FftPlanner) -> Self { + if length % 2 > 0 { + panic!("Length must be even, got {}", length,); + } + let twiddle_count = if length % 4 == 0 { + length / 4 + } else { + length / 4 + 1 + }; + let twiddles: Vec> = (1..twiddle_count) + .map(|i| compute_twiddle(i, length) * T::from_f64(0.5).unwrap()) + .collect(); + //let mut fft_planner = FftPlanner::::new(); + let fft = fft_planner.plan_fft_forward(length / 2); + let scratch_len = fft.get_outofplace_scratch_len(); + RealToComplexEven { + twiddles, + length, + fft, + scratch_len, + } + } +} + +impl RealToComplex for RealToComplexEven { + /// Transform a vector of N real-valued samples, storing the result in the N/2+1 element long complex output vector. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also allocates additional scratch space as needed. + /// An error is returned if any of the given slices has the wrong length. + fn process(&self, input: &mut [T], output: &mut [Complex]) -> Res<()> { + let mut scratch = self.make_scratch_vec(); + self.process_with_scratch(input, output, &mut scratch) + } + + /// Transform a vector of N real-valued samples, storing the result in the N/2+1 element long complex output vector. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once. + /// An error is returned if any of the given slices has the wrong length. + fn process_with_scratch( + &self, + input: &mut [T], + output: &mut [Complex], + scratch: &mut [Complex], + ) -> Res<()> { + if input.len() != self.length { + return Err(Box::new(FftError::new( + format!( + "Wrong length of input, expected {}, got {}", + self.length, + input.len() + ) + .as_str(), + ))); + } + if output.len() != (self.length / 2 + 1) { + return Err(Box::new(FftError::new( + format!( + "Wrong length of output, expected {}, got {}", + self.length / 2 + 1, + input.len() + ) + .as_str(), + ))); + } + if scratch.len() != (self.scratch_len) { + return Err(Box::new(FftError::new( + format!( + "Wrong length of scratch, expected {}, got {}", + self.scratch_len / 2 + 1, + scratch.len() + ) + .as_str(), + ))); + } + + let fftlen = self.length / 2; + let mut buf_in = unsafe { + let ptr = input.as_mut_ptr() as *mut Complex; + let len = input.len(); + std::slice::from_raw_parts_mut(ptr, len / 2) + }; + + // FFT and store result in buffer_out + self.fft + .process_outofplace_with_scratch(&mut buf_in, &mut output[0..fftlen], scratch); + let (mut output_left, mut output_right) = output.split_at_mut(output.len() / 2); + + // The first and last element don't require any twiddle factors, so skip that work + match (output_left.first_mut(), output_right.last_mut()) { + (Some(first_element), Some(last_element)) => { + // The first and last elements are just a sum and difference of the first value's real and imaginary values + let first_value = *first_element; + *first_element = Complex { + re: first_value.re + first_value.im, + im: T::zero(), + }; + *last_element = Complex { + re: first_value.re - first_value.im, + im: T::zero(), }; - self.fft.process_outofplace_with_scratch( - &mut self.buffer_in, - &mut buf_out, - &mut self.scratch, - ); - Ok(()) + + // Chop the first and last element off of our slices so that the loop below doesn't have to deal with them + output_left = &mut output_left[1..]; + let right_len = output_right.len(); + output_right = &mut output_right[..right_len - 1]; + } + _ => { + return Ok(()); + } + } + // Loop over the remaining elements and apply twiddle factors on them + for (twiddle, out, out_rev) in zip3( + self.twiddles.iter(), + output_left.iter_mut(), + output_right.iter_mut().rev(), + ) { + let sum = *out + *out_rev; + let diff = *out - *out_rev; + let half = T::from_f64(0.5).unwrap(); + // Apply twiddle factors. Theoretically we'd have to load 2 separate twiddle factors here, one for the beginning + // and one for the end. But the twiddle factor for the end is jsut the twiddle for the beginning, with the + // real part negated. Since it's the same twiddle, we can factor out a ton of math ops and cut the number of + // multiplications in half + let twiddled_re_sum = sum * twiddle.re; + let twiddled_im_sum = sum * twiddle.im; + let twiddled_re_diff = diff * twiddle.re; + let twiddled_im_diff = diff * twiddle.im; + let half_sum_re = half * sum.re; + let half_diff_im = half * diff.im; + + let output_twiddled_real = twiddled_re_sum.im + twiddled_im_diff.re; + let output_twiddled_im = twiddled_im_sum.im - twiddled_re_diff.re; + + // We finally have all the data we need to write the transformed data back out where we found it + *out = Complex { + re: half_sum_re + output_twiddled_real, + im: half_diff_im + output_twiddled_im, + }; + + *out_rev = Complex { + re: half_sum_re - output_twiddled_real, + im: output_twiddled_im - half_diff_im, + }; + } + + // If the output len is odd, the loop above can't postprocess the centermost element, so handle that separately + if output.len() % 2 == 1 { + if let Some(center_element) = output.get_mut(output.len() / 2) { + center_element.im = -center_element.im; } } - }; + Ok(()) + } + fn get_scratch_len(&self) -> usize { + self.scratch_len + } + + fn len(&self) -> usize { + self.length + } + + fn make_input_vec(&self) -> Vec { + vec![T::zero(); self.len()] + } + + fn make_output_vec(&self) -> Vec> { + vec![Complex::zero(); self.len() / 2 + 1] + } + + fn make_scratch_vec(&self) -> Vec> { + vec![Complex::zero(); self.get_scratch_len()] + } } -impl_c2r!(f64); -impl_c2r!(f32); -#[cfg(test)] -mod tests { - use crate::{ComplexToReal, RealToComplex}; - use rustfft::num_complex::Complex; - use rustfft::num_traits::Zero; - use rustfft::FftPlanner; +impl ComplexToRealOdd { + /// Create a new ComplexToReal FFT for input data of a given length, and uses the given FftPlanner to build the inner FFT. + /// Panics if the length is not odd. + pub fn new(length: usize, fft_planner: &mut FftPlanner) -> Self { + if length % 2 == 0 { + panic!("Length must be odd, got {}", length,); + } + //let mut fft_planner = FftPlanner::::new(); + let fft = fft_planner.plan_fft_inverse(length); + let scratch_len = length + fft.get_inplace_scratch_len(); + ComplexToRealOdd { + length, + fft, + scratch_len, + } + } +} - fn compare_complex(a: &[Complex], b: &[Complex], tol: f64) -> bool { - a.iter().zip(b.iter()).fold(true, |eq, (val_a, val_b)| { - eq && (val_a.re - val_b.re).abs() < tol && (val_a.im - val_b.im).abs() < tol - }) +impl ComplexToReal for ComplexToRealOdd { + /// Transform a complex spectrum of N/2+1 (with N/2 rounded down) values and store the real result in the N long output. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also allocates additional scratch space as needed. + /// An error is returned if any of the given slices has the wrong length. + fn process(&self, input: &mut [Complex], output: &mut [T]) -> Res<()> { + let mut scratch = self.make_scratch_vec(); + self.process_with_scratch(input, output, &mut scratch) } - fn compare_f64(a: &[f64], b: &[f64], tol: f64) -> bool { - a.iter() - .zip(b.iter()) - .fold(true, |eq, (val_a, val_b)| eq && (val_a - val_b).abs() < tol) + /// Transform a complex spectrum of N/2+1 (with N/2 rounded down) values and store the real result in the N long output. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once. + /// An error is returned if any of the given slices has the wrong length. + fn process_with_scratch( + &self, + input: &mut [Complex], + output: &mut [T], + scratch: &mut [Complex], + ) -> Res<()> { + if input.len() != (self.length / 2 + 1) { + return Err(Box::new(FftError::new( + format!( + "Wrong length of input, expected {}, got {}", + self.length / 2 + 1, + input.len() + ) + .as_str(), + ))); + } + if output.len() != self.length { + return Err(Box::new(FftError::new( + format!( + "Wrong length of output, expected {}, got {}", + self.length, + input.len() + ) + .as_str(), + ))); + } + if scratch.len() != (self.scratch_len) { + return Err(Box::new(FftError::new( + format!( + "Wrong length of scratch, expected {}, got {}", + self.scratch_len / 2 + 1, + scratch.len() + ) + .as_str(), + ))); + } + + let (buffer, fft_scratch) = scratch.split_at_mut(self.length); + + buffer[0..input.len()].copy_from_slice(&input); + for (buf, val) in buffer + .iter_mut() + .rev() + .take(self.length / 2) + .zip(input.iter().skip(1)) + { + *buf = val.conj(); + //buf.im = -val.im; + } + self.fft.process_with_scratch(buffer, fft_scratch); + for (val, out) in buffer.iter().zip(output.iter_mut()) { + *out = val.re; + } + Ok(()) } - // Compare RealToComplex with standard FFT - #[test] - fn real_to_complex() { - let mut indata = vec![0.0f64; 256]; - for (i, val) in indata.iter_mut().enumerate() { - *val = i as f64; + fn get_scratch_len(&self) -> usize { + self.scratch_len + } + + fn len(&self) -> usize { + self.length + } + + fn make_input_vec(&self) -> Vec> { + vec![Complex::zero(); self.len() / 2 + 1] + } + + fn make_output_vec(&self) -> Vec { + vec![T::zero(); self.len()] + } + + fn make_scratch_vec(&self) -> Vec> { + vec![Complex::zero(); self.get_scratch_len()] + } +} + +impl ComplexToRealEven { + /// Create a new ComplexToReal FFT for input data of a given length, and uses the given FftPlanner to build the inner FFT. + /// Panics if the length is not even. + pub fn new(length: usize, fft_planner: &mut FftPlanner) -> Self { + if length % 2 > 0 { + panic!("Length must be even, got {}", length,); } - let mut rustfft_check = indata - .iter() - .map(|val| Complex::from(val)) - .collect::>>(); - let mut fft_planner = FftPlanner::::new(); - let fft = fft_planner.plan_fft_forward(256); + let twiddle_count = if length % 4 == 0 { + length / 4 + } else { + length / 4 + 1 + }; + let twiddles: Vec> = (1..twiddle_count) + .map(|i| compute_twiddle(i, length).conj()) + .collect(); + //let mut fft_planner = FftPlanner::::new(); + let fft = fft_planner.plan_fft_inverse(length / 2); + let scratch_len = fft.get_outofplace_scratch_len(); + ComplexToRealEven { + twiddles, + length, + fft, + scratch_len, + } + } +} +impl ComplexToReal for ComplexToRealEven { + /// Transform a complex spectrum of N/2+1 values and store the real result in the N long output. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also allocates additional scratch space as needed. + /// An error is returned if any of the given slices has the wrong length. + fn process(&self, input: &mut [Complex], output: &mut [T]) -> Res<()> { + let mut scratch = self.make_scratch_vec(); + self.process_with_scratch(input, output, &mut scratch) + } - let mut r2c = RealToComplex::::new(256).unwrap(); - let mut out_a: Vec> = vec![Complex::zero(); 129]; + /// Transform a complex spectrum of N/2+1 values and store the real result in the N long output. + /// The input buffer is used as scratch space, so the contents of input should be considered garbage after calling. + /// It also uses the provided scratch vector instead of allocating, which will be faster if it is called more than once. + /// An error is returned if any of the given slices has the wrong length. + fn process_with_scratch( + &self, + input: &mut [Complex], + output: &mut [T], + scratch: &mut [Complex], + ) -> Res<()> { + if input.len() != (self.length / 2 + 1) { + return Err(Box::new(FftError::new( + format!( + "Wrong length of input, expected {}, got {}", + self.length / 2 + 1, + input.len() + ) + .as_str(), + ))); + } + if output.len() != self.length { + return Err(Box::new(FftError::new( + format!( + "Wrong length of output, expected {}, got {}", + self.length, + input.len() + ) + .as_str(), + ))); + } + if scratch.len() != (self.scratch_len) { + return Err(Box::new(FftError::new( + format!( + "Wrong length of scratch, expected {}, got {}", + self.scratch_len / 2 + 1, + scratch.len() + ) + .as_str(), + ))); + } + let (mut input_left, mut input_right) = input.split_at_mut(input.len() / 2); + + // We have to preprocess the input in-place before we send it to the FFT. + // The first and centermost values have to be preprocessed separately from the rest, so do that now + match (input_left.first_mut(), input_right.last_mut()) { + (Some(first_input), Some(last_input)) => { + let first_sum = *first_input + *last_input; + let first_diff = *first_input - *last_input; + + *first_input = Complex { + re: first_sum.re - first_sum.im, + im: first_diff.re - first_diff.im, + }; - fft.process(&mut rustfft_check); - r2c.process(&mut indata, &mut out_a).unwrap(); - assert!(compare_complex( - &out_a[0..129], - &rustfft_check[0..129], - 1.0e-9 - )); + input_left = &mut input_left[1..]; + let right_len = input_right.len(); + input_right = &mut input_right[..right_len - 1]; + } + _ => return Ok(()), + }; + + // now, in a loop, preprocess the rest of the elements 2 at a time + for (twiddle, fft_input, fft_input_rev) in zip3( + self.twiddles.iter(), + input_left.iter_mut(), + input_right.iter_mut().rev(), + ) { + let sum = *fft_input + *fft_input_rev; + let diff = *fft_input - *fft_input_rev; + + // Apply twiddle factors. Theoretically we'd have to load 2 separate twiddle factors here, one for the beginning + // and one for the end. But the twiddle factor for the end is jsut the twiddle for the beginning, with the + // real part negated. Since it's the same twiddle, we can factor out a ton of math ops and cut the number of + // multiplications in half + let twiddled_re_sum = sum * twiddle.re; + let twiddled_im_sum = sum * twiddle.im; + let twiddled_re_diff = diff * twiddle.re; + let twiddled_im_diff = diff * twiddle.im; + + let output_twiddled_real = twiddled_re_sum.im + twiddled_im_diff.re; + let output_twiddled_im = twiddled_im_sum.im - twiddled_re_diff.re; + + // We finally have all the data we need to write our preprocessed data back where we got it from + *fft_input = Complex { + re: sum.re - output_twiddled_real, + im: diff.im - output_twiddled_im, + }; + *fft_input_rev = Complex { + re: sum.re + output_twiddled_real, + im: -output_twiddled_im - diff.im, + } + } + + // If the output len is odd, the loop above can't preprocess the centermost element, so handle that separately + if input.len() % 2 == 1 { + let center_element = input[input.len() / 2]; + let doubled = center_element + center_element; + input[input.len() / 2] = doubled.conj(); + } + + // FFT and store result in buffer_out + let mut buf_out = unsafe { + let ptr = output.as_mut_ptr() as *mut Complex; + let len = output.len(); + std::slice::from_raw_parts_mut(ptr, len / 2) + }; + self.fft.process_outofplace_with_scratch( + &mut input[..output.len() / 2], + &mut buf_out, + scratch, + ); + Ok(()) } - // Compare ComplexToReal with standard iFFT - #[test] - fn complex_to_real() { - let mut indata = vec![Complex::::zero(); 256]; - indata[0] = Complex::new(1.0, 0.0); - indata[1] = Complex::new(1.0, 0.4); - indata[255] = Complex::new(1.0, -0.4); - indata[3] = Complex::new(0.3, 0.2); - indata[253] = Complex::new(0.3, -0.2); - let mut rustfft_check = indata.clone(); + fn get_scratch_len(&self) -> usize { + self.scratch_len + } - let mut fft_planner = FftPlanner::::new(); - let fft = fft_planner.plan_fft_inverse(256); + fn len(&self) -> usize { + self.length + } - let mut c2r = ComplexToReal::::new(256).unwrap(); - let mut out_a: Vec = vec![0.0; 256]; + fn make_input_vec(&self) -> Vec> { + vec![Complex::zero(); self.len() / 2 + 1] + } - c2r.process(&indata[0..129], &mut out_a).unwrap(); - fft.process(&mut rustfft_check); + fn make_output_vec(&self) -> Vec { + vec![T::zero(); self.len()] + } - let check_real = rustfft_check.iter().map(|val| val.re).collect::>(); - assert!(compare_f64(&out_a, &check_real, 1.0e-9)); + fn make_scratch_vec(&self) -> Vec> { + vec![Complex::zero(); self.get_scratch_len()] } +} - // Compare RealToComplex with standard FFT - #[test] - fn real_to_complex_odd() { - let mut indata = vec![0.0f64; 254]; - indata[0] = 1.0; - indata[3] = 0.5; - let mut rustfft_check = indata - .iter() - .map(|val| Complex::from(val)) - .collect::>>(); - let mut fft_planner = FftPlanner::::new(); - let fft = fft_planner.plan_fft_forward(254); - - let mut r2c = RealToComplex::::new(254).unwrap(); - let mut out_a: Vec> = vec![Complex::zero(); 128]; - - fft.process(&mut rustfft_check); - r2c.process(&mut indata, &mut out_a).unwrap(); - assert!(compare_complex( - &out_a[0..128], - &rustfft_check[0..128], - 1.0e-9 - )); +#[cfg(test)] +mod tests { + use crate::RealFftPlanner; + use rand::Rng; + use rustfft::num_complex::Complex; + use rustfft::num_traits::Zero; + use rustfft::FftPlanner; + + // get the largest difference + fn compare_complex(a: &[Complex], b: &[Complex]) -> f64 { + a.iter().zip(b.iter()).fold(0.0, |maxdiff, (val_a, val_b)| { + let diff = (val_a - val_b).norm(); + if maxdiff > diff { + maxdiff + } else { + diff + } + }) + } + + // get the largest difference + fn compare_f64(a: &[f64], b: &[f64]) -> f64 { + a.iter().zip(b.iter()).fold(0.0, |maxdiff, (val_a, val_b)| { + let diff = (val_a - val_b).abs(); + if maxdiff > diff { + maxdiff + } else { + diff + } + }) } // Compare ComplexToReal with standard iFFT #[test] - fn complex_to_real_odd() { - let mut indata = vec![Complex::::zero(); 254]; - indata[0] = Complex::new(1.0, 0.0); - indata[1] = Complex::new(1.0, 0.4); - indata[253] = Complex::new(1.0, -0.4); - indata[3] = Complex::new(0.3, 0.2); - indata[251] = Complex::new(0.3, -0.2); - let mut rustfft_check = indata.clone(); - - let mut fft_planner = FftPlanner::::new(); - let fft = fft_planner.plan_fft_inverse(254); - - let mut c2r = ComplexToReal::::new(254).unwrap(); - let mut out_a: Vec = vec![0.0; 254]; - - c2r.process(&indata[0..128], &mut out_a).unwrap(); - fft.process(&mut rustfft_check); - let check_real = rustfft_check.iter().map(|val| val.re).collect::>(); - assert!(compare_f64(&out_a[0..128], &check_real[0..128], 1.0e-9)); + fn complex_to_real() { + for length in 1..1000 { + let mut real_planner = RealFftPlanner::::new(); + let c2r = real_planner.plan_fft_inverse(length); + let mut out_a = c2r.make_output_vec(); + let mut indata = c2r.make_input_vec(); + let mut rustfft_check: Vec> = vec![Complex::zero(); length]; + let mut rng = rand::thread_rng(); + for val in indata.iter_mut() { + *val = Complex::new(rng.gen::(), rng.gen::()); + } + indata[0].im = 0.0; + if length % 2 == 0 { + indata[length / 2].im = 0.0; + } + for (val_long, val) in rustfft_check + .iter_mut() + .take(length / 2 + 1) + .zip(indata.iter()) + { + *val_long = *val; + } + for (val_long, val) in rustfft_check + .iter_mut() + .rev() + .take(length / 2) + .zip(indata.iter().skip(1)) + { + *val_long = val.conj(); + } + let mut fft_planner = FftPlanner::::new(); + let fft = fft_planner.plan_fft_inverse(length); + + c2r.process(&mut indata, &mut out_a).unwrap(); + fft.process(&mut rustfft_check); + + let check_real = rustfft_check.iter().map(|val| val.re).collect::>(); + let maxdiff = compare_f64(&out_a, &check_real); + assert!( + maxdiff < 1.0e-9, + "Length: {}, too large error: {}", + length, + maxdiff + ); + } + } + + // Compare RealToComplex with standard FFT + #[test] + fn real_to_complex() { + for length in 1..1000 { + let mut real_planner = RealFftPlanner::::new(); + let r2c = real_planner.plan_fft_forward(length); + let mut out_a = r2c.make_output_vec(); + let mut indata = r2c.make_input_vec(); + let mut rng = rand::thread_rng(); + for val in indata.iter_mut() { + *val = rng.gen::(); + } + let mut rustfft_check = indata + .iter() + .map(|val| Complex::from(val)) + .collect::>>(); + let mut fft_planner = FftPlanner::::new(); + let fft = fft_planner.plan_fft_forward(length); + + fft.process(&mut rustfft_check); + r2c.process(&mut indata, &mut out_a).unwrap(); + let maxdiff = compare_complex(&out_a, &rustfft_check[0..(length / 2 + 1)]); + assert!( + maxdiff < 1.0e-9, + "Length: {}, too large error: {}", + length, + maxdiff + ); + } } }