From b7b8067eb4ee388eb2222caede2d9449360e3e03 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Sun, 18 Feb 2024 16:33:45 -0800 Subject: [PATCH 01/13] Optimized big SSE butterflies to be more cache-friendly --- benches/bench_rustfft_sse.rs | 2 + src/sse/sse_butterflies.rs | 1856 +++++++++++++--------------------- src/sse/sse_utils.rs | 42 + src/sse/sse_vector.rs | 13 + 4 files changed, 777 insertions(+), 1136 deletions(-) diff --git a/benches/bench_rustfft_sse.rs b/benches/bench_rustfft_sse.rs index 962e7e56..fe80fdba 100644 --- a/benches/bench_rustfft_sse.rs +++ b/benches/bench_rustfft_sse.rs @@ -76,6 +76,7 @@ fn bench_planned_multi_f64(b: &mut Bencher, len: usize) { #[bench] fn sse_butterfly32_17(b: &mut Bencher) { bench_planned_multi_f32(b, 17);} #[bench] fn sse_butterfly32_19(b: &mut Bencher) { bench_planned_multi_f32(b, 19);} #[bench] fn sse_butterfly32_23(b: &mut Bencher) { bench_planned_multi_f32(b, 23);} +#[bench] fn sse_butterfly32_24(b: &mut Bencher) { bench_planned_multi_f32(b, 24);} #[bench] fn sse_butterfly32_29(b: &mut Bencher) { bench_planned_multi_f32(b, 29);} #[bench] fn sse_butterfly32_31(b: &mut Bencher) { bench_planned_multi_f32(b, 31);} #[bench] fn sse_butterfly32_32(b: &mut Bencher) { bench_planned_multi_f32(b, 32);} @@ -97,6 +98,7 @@ fn bench_planned_multi_f64(b: &mut Bencher, len: usize) { #[bench] fn sse_butterfly64_17(b: &mut Bencher) { bench_planned_multi_f64(b, 17);} #[bench] fn sse_butterfly64_19(b: &mut Bencher) { bench_planned_multi_f64(b, 19);} #[bench] fn sse_butterfly64_23(b: &mut Bencher) { bench_planned_multi_f64(b, 23);} +#[bench] fn sse_butterfly64_24(b: &mut Bencher) { bench_planned_multi_f64(b, 24);} #[bench] fn sse_butterfly64_29(b: &mut Bencher) { bench_planned_multi_f64(b, 29);} #[bench] fn sse_butterfly64_31(b: &mut Bencher) { bench_planned_multi_f64(b, 31);} #[bench] fn sse_butterfly64_32(b: &mut Bencher) { bench_planned_multi_f64(b, 32);} diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index a160fd66..3b1bb3ce 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -12,7 +12,17 @@ use crate::{Direction, Fft, Length}; use super::sse_common::{assert_f32, assert_f64}; use super::sse_utils::*; -use super::sse_vector::{SseArrayMut, SseVector}; +use super::sse_vector::{SseArray, SseArrayMut, SseVector}; + +#[inline(always)] +unsafe fn pack_32(a: Complex, b: Complex) -> __m128 { + [a,b].as_slice().load_complex(0) +} +#[inline(always)] +unsafe fn pack_64(a: Complex) -> __m128d { + [a].as_slice().load_complex(0) +} + #[allow(unused)] macro_rules! boilerplate_fft_sse_f32_butterfly { @@ -83,6 +93,46 @@ macro_rules! boilerplate_fft_sse_f32_butterfly { }; } +macro_rules! 
boilerplate_fft_sse_f32_butterfly_noparallel { + ($struct_name:ident, $len:expr, $direction_fn:expr) => { + impl $struct_name { + // Do a single fft + #[target_feature(enable = "sse4.1")] + pub(crate) unsafe fn perform_fft_butterfly(&self, buffer: &mut [Complex]) { + self.perform_fft_contiguous(workaround_transmute_mut::<_, Complex>(buffer)); + } + + // Do multiple ffts over a longer vector inplace, called from "process_with_scratch" of Fft trait + #[target_feature(enable = "sse4.1")] + pub(crate) unsafe fn perform_fft_butterfly_multi( + &self, + buffer: &mut [Complex], + ) -> Result<(), ()> { + array_utils::iter_chunks(buffer, self.len(), |chunk| { + self.perform_fft_butterfly(chunk) + }) + } + + // Do multiple ffts over a longer vector outofplace, called from "process_outofplace_with_scratch" of Fft trait + #[target_feature(enable = "sse4.1")] + pub(crate) unsafe fn perform_oop_fft_butterfly_multi( + &self, + input: &mut [Complex], + output: &mut [Complex], + ) -> Result<(), ()> { + array_utils::iter_chunks_zipped(input, output, self.len(), |in_chunk, out_chunk| { + let input_slice = workaround_transmute_mut(in_chunk); + let output_slice = workaround_transmute_mut(out_chunk); + self.perform_fft_contiguous(DoubleBuf { + input: input_slice, + output: output_slice, + }) + }) + } + } + }; +} + macro_rules! boilerplate_fft_sse_f64_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { impl $struct_name { @@ -645,7 +695,7 @@ impl SseF32Butterfly4 { let [value0ab, value1ab] = transpose_complex_2x2_f32(value01a, value01b); let [value2ab, value3ab] = transpose_complex_2x2_f32(value23a, value23b); - let out = self.perform_parallel_fft_direct(value0ab, value1ab, value2ab, value3ab); + let out = self.perform_parallel_fft_direct([value0ab, value1ab, value2ab, value3ab]); let [out0, out1] = transpose_complex_2x2_f32(out[0], out[1]); let [out2, out3] = transpose_complex_2x2_f32(out[2], out[3]); @@ -684,21 +734,15 @@ impl SseF32Butterfly4 { } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_direct( - &self, - values0: __m128, - values1: __m128, - values2: __m128, - values3: __m128, - ) -> [__m128; 4] { + pub(crate) unsafe fn perform_parallel_fft_direct(&self, values: [__m128; 4]) -> [__m128; 4] { //we're going to hardcode a step of mixed radix //aka we're going to do the six step algorithm // step 1: transpose // and // step 2: column FFTs - let temp0 = parallel_fft2_interleaved_f32(values0, values2); - let mut temp1 = parallel_fft2_interleaved_f32(values1, values3); + let temp0 = parallel_fft2_interleaved_f32(values[0], values[2]); + let mut temp1 = parallel_fft2_interleaved_f32(values[1], values[3]); // step 3: apply twiddle factors (only one in this case, and it's either 0 + i or 0 - i) temp1[1] = self.rotate.rotate_both(temp1[1]); @@ -755,7 +799,7 @@ impl SseF64Butterfly4 { let value2 = buffer.load_complex(2); let value3 = buffer.load_complex(3); - let out = self.perform_fft_direct(value0, value1, value2, value3); + let out = self.perform_fft_direct([value0, value1, value2, value3]); buffer.store_complex(out[0], 0); buffer.store_complex(out[1], 1); @@ -764,21 +808,15 @@ impl SseF64Butterfly4 { } #[inline(always)] - pub(crate) unsafe fn perform_fft_direct( - &self, - value0: __m128d, - value1: __m128d, - value2: __m128d, - value3: __m128d, - ) -> [__m128d; 4] { + pub(crate) unsafe fn perform_fft_direct(&self, values: [__m128d; 4]) -> [__m128d; 4] { //we're going to hardcode a step of mixed radix //aka we're going to do the six step algorithm // step 1: transpose // and // 
step 2: column FFTs - let temp0 = solo_fft2_f64(value0, value2); - let mut temp1 = solo_fft2_f64(value1, value3); + let temp0 = solo_fft2_f64(values[0], values[2]); + let mut temp1 = solo_fft2_f64(values[1], values[3]); // step 3: apply twiddle factors (only one in this case, and it's either 0 + i or 0 - i) temp1[1] = self.rotate.rotate(temp1[1]); @@ -1226,7 +1264,7 @@ impl SseF64Butterfly6 { let value4 = buffer.load_complex(4); let value5 = buffer.load_complex(5); - let out = self.perform_fft_direct(value0, value1, value2, value3, value4, value5); + let out = self.perform_fft_direct([value0, value1, value2, value3, value4, value5]); buffer.store_complex(out[0], 0); buffer.store_complex(out[1], 1); @@ -1237,20 +1275,12 @@ impl SseF64Butterfly6 { } #[inline(always)] - pub(crate) unsafe fn perform_fft_direct( - &self, - value0: __m128d, - value1: __m128d, - value2: __m128d, - value3: __m128d, - value4: __m128d, - value5: __m128d, - ) -> [__m128d; 6] { + pub(crate) unsafe fn perform_fft_direct(&self, values: [__m128d; 6]) -> [__m128d; 6] { // Algorithm: 3x2 good-thomas // Size-3 FFTs down the columns of our reordered array - let mid0 = self.bf3.perform_fft_direct(value0, value2, value4); - let mid1 = self.bf3.perform_fft_direct(value3, value5, value1); + let mid0 = self.bf3.perform_fft_direct(values[0], values[2], values[4]); + let mid1 = self.bf3.perform_fft_direct(values[3], values[5], values[1]); // We normally would put twiddle factors right here, but since this is good-thomas algorithm, we don't need twiddle factors @@ -1366,10 +1396,10 @@ impl SseF32Butterfly8 { // step 2: column FFTs let val03 = self .bf4 - .perform_parallel_fft_direct(values[0], values[2], values[4], values[6]); + .perform_parallel_fft_direct([values[0], values[2], values[4], values[6]]); let mut val47 = self .bf4 - .perform_parallel_fft_direct(values[1], values[3], values[5], values[7]); + .perform_parallel_fft_direct([values[1], values[3], values[5], values[7]]); // step 3: apply twiddle factors let val5b = self.rotate90.rotate_both(val47[1]); @@ -1403,32 +1433,19 @@ impl SseF32Butterfly8 { // pub struct SseF64Butterfly8 { - root2: __m128d, - direction: FftDirection, bf4: SseF64Butterfly4, - rotate90: Rotate90F64, } boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly8, 8, |this: &SseF64Butterfly8<_>| this - .direction); + .bf4.direction); boilerplate_fft_sse_common_butterfly!(SseF64Butterfly8, 8, |this: &SseF64Butterfly8<_>| this - .direction); + .bf4.direction); impl SseF64Butterfly8 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f64::(); - let bf4 = SseF64Butterfly4::new(direction); - let root2 = unsafe { _mm_load1_pd(&0.5f64.sqrt()) }; - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F64::new(true) - } else { - Rotate90F64::new(false) - }; Self { - root2, - direction, - bf4, - rotate90, + bf4: SseF64Butterfly4::new(direction), } } @@ -1449,19 +1466,15 @@ impl SseF64Butterfly8 { // step 2: column FFTs let val03 = self .bf4 - .perform_fft_direct(values[0], values[2], values[4], values[6]); + .perform_fft_direct([values[0], values[2], values[4], values[6]]); let mut val47 = self .bf4 - .perform_fft_direct(values[1], values[3], values[5], values[7]); + .perform_fft_direct([values[1], values[3], values[5], values[7]]); // step 3: apply twiddle factors - let val5b = self.rotate90.rotate(val47[1]); - let val5c = _mm_add_pd(val5b, val47[1]); - val47[1] = _mm_mul_pd(val5c, self.root2); - val47[2] = self.rotate90.rotate(val47[2]); - let val7b = self.rotate90.rotate(val47[3]); - 
let val7c = _mm_sub_pd(val7b, val47[3]); - val47[3] = _mm_mul_pd(val7c, self.root2); + val47[1] = self.bf4.rotate.rotate_45(val47[1]); + val47[2] = self.bf4.rotate.rotate(val47[2]); + val47[3] = self.bf4.rotate.rotate_135(val47[3]); // step 4: transpose -- skipped because we're going to do the next FFTs non-contiguously @@ -1957,13 +1970,13 @@ impl SseF32Butterfly12 { // Size-4 FFTs down the columns of our reordered array let mid0 = self .bf4 - .perform_parallel_fft_direct(values[0], values[3], values[6], values[9]); + .perform_parallel_fft_direct([values[0], values[3], values[6], values[9]]); let mid1 = self .bf4 - .perform_parallel_fft_direct(values[4], values[7], values[10], values[1]); + .perform_parallel_fft_direct([values[4], values[7], values[10], values[1]]); let mid2 = self .bf4 - .perform_parallel_fft_direct(values[8], values[11], values[2], values[5]); + .perform_parallel_fft_direct([values[8], values[11], values[2], values[5]]); // Since this is good-thomas algorithm, we don't need twiddle factors @@ -2037,13 +2050,13 @@ impl SseF64Butterfly12 { // Size-4 FFTs down the columns of our reordered array let mid0 = self .bf4 - .perform_fft_direct(values[0], values[3], values[6], values[9]); + .perform_fft_direct([values[0], values[3], values[6], values[9]]); let mid1 = self .bf4 - .perform_fft_direct(values[4], values[7], values[10], values[1]); + .perform_fft_direct([values[4], values[7], values[10], values[1]]); let mid2 = self .bf4 - .perform_fft_direct(values[8], values[11], values[2], values[5]); + .perform_fft_direct([values[8], values[11], values[2], values[5]]); // Since this is good-thomas algorithm, we don't need twiddle factors @@ -2271,199 +2284,141 @@ impl SseF64Butterfly15 { // pub struct SseF32Butterfly16 { - direction: FftDirection, bf4: SseF32Butterfly4, - bf8: SseF32Butterfly8, - rotate90: Rotate90F32, - twiddle01: __m128, - twiddle23: __m128, - twiddle01conj: __m128, - twiddle23conj: __m128, + twiddles_packed: [__m128; 6], twiddle1: __m128, twiddle2: __m128, twiddle3: __m128, - twiddle1c: __m128, - twiddle2c: __m128, - twiddle3c: __m128, + twiddle6: __m128, + twiddle9: __m128, } -boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly16, 16, |this: &SseF32Butterfly16<_>| this - .direction); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly16, 16, |this: &SseF32Butterfly16<_>| this +boilerplate_fft_sse_f32_butterfly_noparallel!( + SseF32Butterfly16, + 16, + |this: &SseF32Butterfly16<_>| this.bf4.direction +); +boilerplate_fft_sse_common_butterfly!(SseF32Butterfly16, 16, |this: &SseF32Butterfly16<_>| this.bf4 .direction); impl SseF32Butterfly16 { - #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f32::(); - let bf8 = SseF32Butterfly8::new(direction); - let bf4 = SseF32Butterfly4::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F32::new(true) - } else { - Rotate90F32::new(false) - }; + let tw0: Complex = Complex { re: 1.0, im: 0.0 }; let tw1: Complex = twiddles::compute_twiddle(1, 16, direction); let tw2: Complex = twiddles::compute_twiddle(2, 16, direction); let tw3: Complex = twiddles::compute_twiddle(3, 16, direction); - let twiddle01 = unsafe { _mm_set_ps(tw1.im, tw1.re, 0.0, 1.0) }; - let twiddle23 = unsafe { _mm_set_ps(tw3.im, tw3.re, tw2.im, tw2.re) }; - let twiddle01conj = unsafe { _mm_set_ps(-tw1.im, tw1.re, 0.0, 1.0) }; - let twiddle23conj = unsafe { _mm_set_ps(-tw3.im, tw3.re, -tw2.im, tw2.re) }; - let twiddle1 = unsafe { _mm_set_ps(tw1.im, tw1.re, tw1.im, tw1.re) }; - let twiddle2 = unsafe { 
_mm_set_ps(tw2.im, tw2.re, tw2.im, tw2.re) }; - let twiddle3 = unsafe { _mm_set_ps(tw3.im, tw3.re, tw3.im, tw3.re) }; - let twiddle1c = unsafe { _mm_set_ps(-tw1.im, tw1.re, -tw1.im, tw1.re) }; - let twiddle2c = unsafe { _mm_set_ps(-tw2.im, tw2.re, -tw2.im, tw2.re) }; - let twiddle3c = unsafe { _mm_set_ps(-tw3.im, tw3.re, -tw3.im, tw3.re) }; - Self { - direction, - bf4, - bf8, - rotate90, - twiddle01, - twiddle23, - twiddle01conj, - twiddle23conj, - twiddle1, - twiddle2, - twiddle3, - twiddle1c, - twiddle2c, - twiddle3c, + let tw4: Complex = twiddles::compute_twiddle(4, 16, direction); + let tw6: Complex = twiddles::compute_twiddle(6, 16, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 16, direction); + + unsafe { + Self { + bf4: SseF32Butterfly4::new(direction), + twiddles_packed: [ + pack_32(tw0, tw1), + pack_32(tw0, tw2), + pack_32(tw0, tw3), + pack_32(tw2, tw3), + pack_32(tw4, tw6), + pack_32(tw6, tw9), + ], + twiddle1: pack_32(tw1, tw1), + twiddle2: pack_32(tw2, tw2), + twiddle3: pack_32(tw3, tw3), + twiddle6: pack_32(tw6, tw6), + twiddle9: pack_32(tw9, tw9), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14 }); + let load = |i| [ + buffer.load_complex(i), + buffer.load_complex(i + 4), + buffer.load_complex(i + 8), + buffer.load_complex(i + 12), + ]; - let out = self.perform_fft_direct(input_packed); + let mut tmp0 = self.bf4.perform_parallel_fft_direct(load(0)); + tmp0[1] = SseVector::mul_complex(tmp0[1], self.twiddles_packed[0]); + tmp0[2] = SseVector::mul_complex(tmp0[2], self.twiddles_packed[1]); + tmp0[3] = SseVector::mul_complex(tmp0[3], self.twiddles_packed[2]); + let [mid0, mid1] = transpose_complex_2x2_f32(tmp0[0], tmp0[1]); + let [mid4, mid5] = transpose_complex_2x2_f32(tmp0[2], tmp0[3]); + + let mut tmp1 = self.bf4.perform_parallel_fft_direct(load(2)); + tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddles_packed[3]); + tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddles_packed[4]); + tmp1[3] = SseVector::mul_complex(tmp1[3], self.twiddles_packed[5]); + let [mid2, mid3] = transpose_complex_2x2_f32(tmp1[0], tmp1[1]); + let [mid6, mid7] = transpose_complex_2x2_f32(tmp1[2], tmp1[3]); + + // cross FFTs + let mut store = |i: usize, vectors: [__m128; 4]| { + buffer.store_complex(vectors[0], i + 0); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + }; + let out0 = self.bf4.perform_parallel_fft_direct([mid0, mid1, mid2, mid3]); + store(0, out0); - write_complex_to_array_strided!(out, buffer, 2, {0,1,2,3,4,5,6,7}); + let out1 = self.bf4.perform_parallel_fft_direct([mid4, mid5, mid6, mid7]); + store(2, out1); } - #[inline(always)] + // benchmarking shows it's faster to always use the nonparallel version, but this is kep around for reference + #[allow(unused)] pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}); - - let values = interleave_complex_f32!(input_packed, 8, {0, 1, 2, 3 ,4 ,5 ,6 ,7}); - - let out = self.perform_parallel_fft_direct(values); - - let out_sorted = separate_interleaved_complex_f32!(out, {0, 2, 4, 6, 8, 10, 12, 14}); - - write_complex_to_array_strided!(out_sorted, buffer, 2, {0,1,2,3,4,5,6,7,8,9, 10, 11,12,13,14, 15}); - } - - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: 
[__m128; 8]) -> [__m128; 8] { - // we're going to hardcode a step of split radix - // step 1: copy and reorder the input into the scratch - let in0002 = extract_lo_lo_f32(input[0], input[1]); - let in0406 = extract_lo_lo_f32(input[2], input[3]); - let in0810 = extract_lo_lo_f32(input[4], input[5]); - let in1214 = extract_lo_lo_f32(input[6], input[7]); - - let in0105 = extract_hi_hi_f32(input[0], input[2]); - let in0913 = extract_hi_hi_f32(input[4], input[6]); - let in1503 = extract_hi_hi_f32(input[7], input[1]); - let in0711 = extract_hi_hi_f32(input[3], input[5]); - - let in_evens = [in0002, in0406, in0810, in1214]; - - // step 2: column FFTs - let evens = self.bf8.perform_fft_direct(in_evens); - let mut odds1 = self.bf4.perform_fft_direct(in0105, in0913); - let mut odds3 = self.bf4.perform_fft_direct(in1503, in0711); - - // step 3: apply twiddle factors - odds1[0] = SseVector::mul_complex(odds1[0], self.twiddle01); - odds3[0] = SseVector::mul_complex(odds3[0], self.twiddle01conj); - - odds1[1] = SseVector::mul_complex(odds1[1], self.twiddle23); - odds3[1] = SseVector::mul_complex(odds3[1], self.twiddle23conj); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - - //step 5: copy/add/subtract data back to buffer - [ - _mm_add_ps(evens[0], temp0[0]), - _mm_add_ps(evens[1], temp1[0]), - _mm_add_ps(evens[2], temp0[1]), - _mm_add_ps(evens[3], temp1[1]), - _mm_sub_ps(evens[0], temp0[0]), - _mm_sub_ps(evens[1], temp1[0]), - _mm_sub_ps(evens[2], temp0[1]), - _mm_sub_ps(evens[3], temp1[1]), - ] - } - - #[inline(always)] - unsafe fn perform_parallel_fft_direct(&self, input: [__m128; 16]) -> [__m128; 16] { - // we're going to hardcode a step of split radix - // step 1: copy and reorder the input into the scratch + // we're going to hardcode a step of 4x4 mixed radix + // step 1: transpose (skipped since the vectors mean our data is already in the correct format) // and // step 2: column FFTs - let evens = self.bf8.perform_parallel_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - ]); - let mut odds1 = self - .bf4 - .perform_parallel_fft_direct(input[1], input[5], input[9], input[13]); - let mut odds3 = self - .bf4 - .perform_parallel_fft_direct(input[15], input[3], input[7], input[11]); - - // step 3: apply twiddle factors - odds1[1] = SseVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = SseVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = SseVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = SseVector::mul_complex(odds3[2], self.twiddle2c); - - odds1[3] = SseVector::mul_complex(odds1[3], self.twiddle3); - odds3[3] = SseVector::mul_complex(odds3[3], self.twiddle3c); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); + // and + // step 3: twiddle factors + let load = |i: usize| { + let [a0, a1] = transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 16)); + let [b0, b1] = transpose_complex_2x2_f32(buffer.load_complex(i + 4), buffer.load_complex(i + 20)); + let [c0, c1] = 
transpose_complex_2x2_f32(buffer.load_complex(i + 8), buffer.load_complex(i + 24)); + let [d0, d1] = transpose_complex_2x2_f32(buffer.load_complex(i + 12), buffer.load_complex(i + 28)); + [[a0, b0, c0, d0], [a1, b1, c1, d1]] + }; - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); + let [in2, in3] = load(2); + let mut tmp2 = self.bf4.perform_parallel_fft_direct(in2); + let mut tmp3 = self.bf4.perform_parallel_fft_direct(in3); + tmp2[1] = SseVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[2] = self.bf4.rotate.rotate_both(tmp2[2]); + tmp2[3] = SseVector::mul_complex(tmp2[3], self.twiddle6); + tmp3[1] = SseVector::mul_complex(tmp3[1], self.twiddle3); + tmp3[2] = SseVector::mul_complex(tmp3[2], self.twiddle6); + tmp3[3] = SseVector::mul_complex(tmp3[3], self.twiddle9); + + let [in0, in1] = load(0); + let tmp0 = self.bf4.perform_parallel_fft_direct(in0); + let mut tmp1 = self.bf4.perform_parallel_fft_direct(in1); + tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[3] = SseVector::mul_complex(tmp1[3], self.twiddle3); + + // step 4 and 5: transpose and cross FFTs + let mut store = |i, values_a: [__m128; 4], values_b: [__m128; 4]| { + for n in 0..4 { + let [a, b] = transpose_complex_2x2_f32(values_a[n], values_b[n]); + buffer.store_complex(a, i + n*4); + buffer.store_complex(b, i + n*4 + 16); + } + }; + let out0 = self.bf4.perform_parallel_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); + let out1 = self.bf4.perform_parallel_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); + store(0, out0, out1); - //step 5: copy/add/subtract data back to buffer - [ - _mm_add_ps(evens[0], temp0[0]), - _mm_add_ps(evens[1], temp1[0]), - _mm_add_ps(evens[2], temp2[0]), - _mm_add_ps(evens[3], temp3[0]), - _mm_add_ps(evens[4], temp0[1]), - _mm_add_ps(evens[5], temp1[1]), - _mm_add_ps(evens[6], temp2[1]), - _mm_add_ps(evens[7], temp3[1]), - _mm_sub_ps(evens[0], temp0[0]), - _mm_sub_ps(evens[1], temp1[0]), - _mm_sub_ps(evens[2], temp2[0]), - _mm_sub_ps(evens[3], temp3[0]), - _mm_sub_ps(evens[4], temp0[1]), - _mm_sub_ps(evens[5], temp1[1]), - _mm_sub_ps(evens[6], temp2[1]), - _mm_sub_ps(evens[7], temp3[1]), - ] + let out2 = self.bf4.perform_parallel_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2]]); + let out3 = self.bf4.perform_parallel_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3]]); + store(2, out2, out3); } } - // _ __ __ _ _ _ _ _ // / |/ /_ / /_ | || | | |__ (_) |_ // | | '_ \ _____ | '_ \| || |_| '_ \| | __| @@ -2472,131 +2427,85 @@ impl SseF32Butterfly16 { // pub struct SseF64Butterfly16 { - direction: FftDirection, bf4: SseF64Butterfly4, - bf8: SseF64Butterfly8, - rotate90: Rotate90F64, twiddle1: __m128d, - twiddle2: __m128d, twiddle3: __m128d, - twiddle1c: __m128d, - twiddle2c: __m128d, - twiddle3c: __m128d, + twiddle9: __m128d, } boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly16, 16, |this: &SseF64Butterfly16<_>| this - .direction); + .bf4.direction); boilerplate_fft_sse_common_butterfly!(SseF64Butterfly16, 16, |this: &SseF64Butterfly16<_>| this - .direction); + .bf4.direction); impl SseF64Butterfly16 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f64::(); - let bf8 = SseF64Butterfly8::new(direction); - let bf4 = SseF64Butterfly4::new(direction); - let rotate90 = if direction == 
FftDirection::Inverse { - Rotate90F64::new(true) - } else { - Rotate90F64::new(false) - }; - let twiddle1 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(1, 16, direction).re as *const f64) }; - let twiddle2 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(2, 16, direction).re as *const f64) }; - let twiddle3 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(3, 16, direction).re as *const f64) }; - let twiddle1c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(1, 16, direction).conj().re as *const f64) - }; - let twiddle2c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(2, 16, direction).conj().re as *const f64) - }; - let twiddle3c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(3, 16, direction).conj().re as *const f64) - }; - - Self { - direction, - bf4, - bf8, - rotate90, - twiddle1, - twiddle2, - twiddle3, - twiddle1c, - twiddle2c, - twiddle3c, + let tw1: Complex = twiddles::compute_twiddle(1, 16, direction); + let tw3: Complex = twiddles::compute_twiddle(3, 16, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 16, direction); + + unsafe { + Self { + bf4: SseF64Butterfly4::new(direction), + twiddle1: pack_64(tw1), + twiddle3: pack_64(tw3), + twiddle9: pack_64(tw9), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - let values = - read_complex_to_array!(buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - - let out = self.perform_fft_direct(values); - - write_complex_to_array!(out, buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - } - - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [__m128d; 16]) -> [__m128d; 16] { - // we're going to hardcode a step of split radix - - // step 1: copy and reorder the input into the scratch + // we're going to hardcode a step of 4x4 mixed radix + // step 1: transpose (skipped since the vectors mean our data is already in the correct format) // and // step 2: column FFTs - let evens = self.bf8.perform_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - ]); - let mut odds1 = self - .bf4 - .perform_fft_direct(input[1], input[5], input[9], input[13]); - let mut odds3 = self - .bf4 - .perform_fft_direct(input[15], input[3], input[7], input[11]); - - // step 3: apply twiddle factors - odds1[1] = SseVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = SseVector::mul_complex(odds3[1], self.twiddle1c); + // and + // step 3: twiddle factors + let load = |i| [ + buffer.load_complex(i), + buffer.load_complex(i + 4), + buffer.load_complex(i + 8), + buffer.load_complex(i + 12), + ]; - odds1[2] = SseVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = SseVector::mul_complex(odds3[2], self.twiddle2c); + let mut tmp1 = self.bf4.perform_fft_direct(load(1)); + tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = self.bf4.rotate.rotate_45(tmp1[2]); + tmp1[3] = SseVector::mul_complex(tmp1[3], self.twiddle3); + + let mut tmp3 = self.bf4.perform_fft_direct(load(3)); + tmp3[1] = SseVector::mul_complex(tmp3[1], self.twiddle3); + tmp3[2] = self.bf4.rotate.rotate_135(tmp3[2]); + tmp3[3] = SseVector::mul_complex(tmp3[3], self.twiddle9); + + let mut tmp2 = self.bf4.perform_fft_direct(load(2)); + tmp2[1] = self.bf4.rotate.rotate_45(tmp2[1]); + tmp2[2] = self.bf4.rotate.rotate(tmp2[2]); + tmp2[3] = self.bf4.rotate.rotate_135(tmp2[3]); + + let tmp0 = self.bf4.perform_fft_direct(load(0)); + + // step 4 and 5: transpose and cross FFTs + let mut store = |i: usize, vectors: 
[__m128d; 4]| { + buffer.store_complex(vectors[0], i + 0); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + }; - odds1[3] = SseVector::mul_complex(odds1[3], self.twiddle3); - odds3[3] = SseVector::mul_complex(odds3[3], self.twiddle3c); + let out0 = self.bf4.perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); + store(0, out0); - // step 4: cross FFTs - let mut temp0 = solo_fft2_f64(odds1[0], odds3[0]); - let mut temp1 = solo_fft2_f64(odds1[1], odds3[1]); - let mut temp2 = solo_fft2_f64(odds1[2], odds3[2]); - let mut temp3 = solo_fft2_f64(odds1[3], odds3[3]); + let out1 = self.bf4.perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); + store(1, out1); - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate(temp0[1]); - temp1[1] = self.rotate90.rotate(temp1[1]); - temp2[1] = self.rotate90.rotate(temp2[1]); - temp3[1] = self.rotate90.rotate(temp3[1]); + let out2 = self.bf4.perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2]]); + store(2, out2); - //step 5: copy/add/subtract data back to buffer - [ - _mm_add_pd(evens[0], temp0[0]), - _mm_add_pd(evens[1], temp1[0]), - _mm_add_pd(evens[2], temp2[0]), - _mm_add_pd(evens[3], temp3[0]), - _mm_add_pd(evens[4], temp0[1]), - _mm_add_pd(evens[5], temp1[1]), - _mm_add_pd(evens[6], temp2[1]), - _mm_add_pd(evens[7], temp3[1]), - _mm_sub_pd(evens[0], temp0[0]), - _mm_sub_pd(evens[1], temp1[0]), - _mm_sub_pd(evens[2], temp2[0]), - _mm_sub_pd(evens[3], temp3[0]), - _mm_sub_pd(evens[4], temp0[1]), - _mm_sub_pd(evens[5], temp1[1]), - _mm_sub_pd(evens[6], temp2[1]), - _mm_sub_pd(evens[7], temp3[1]), - ] + let out3 = self.bf4.perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3]]); + store(3, out3); } } @@ -2608,261 +2517,169 @@ impl SseF64Butterfly16 { // pub struct SseF32Butterfly24 { - direction: FftDirection, + bf4: SseF32Butterfly4, bf6: SseF32Butterfly6, - bf12: SseF32Butterfly12, - rotate90: Rotate90F32, - twiddle01: __m128, - twiddle23: __m128, - twiddle45: __m128, - twiddle01conj: __m128, - twiddle23conj: __m128, - twiddle45conj: __m128, + twiddles_packed: [__m128; 9], twiddle1: __m128, twiddle2: __m128, twiddle4: __m128, twiddle5: __m128, - twiddle1c: __m128, - twiddle2c: __m128, - twiddle4c: __m128, - twiddle5c: __m128, + twiddle8: __m128, + twiddle10: __m128, } boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly24, 24, |this: &SseF32Butterfly24<_>| { - this.direction + this.bf4.direction }); boilerplate_fft_sse_common_butterfly!(SseF32Butterfly24, 24, |this: &SseF32Butterfly24<_>| this - .direction); + .bf4.direction); impl SseF32Butterfly24 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f32::(); - let bf6 = SseF32Butterfly6::new(direction); - let bf12 = SseF32Butterfly12::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F32::new(true) - } else { - Rotate90F32::new(false) - }; + let tw0: Complex = Complex { re: 1.0, im: 0.0 }; let tw1: Complex = twiddles::compute_twiddle(1, 24, direction); let tw2: Complex = twiddles::compute_twiddle(2, 24, direction); let tw3: Complex = twiddles::compute_twiddle(3, 24, direction); let tw4: Complex = twiddles::compute_twiddle(4, 24, direction); let tw5: Complex = twiddles::compute_twiddle(5, 24, direction); - let twiddle01 = unsafe { _mm_set_ps(tw1.im, tw1.re, 0.0, 1.0) }; - let twiddle23 = unsafe { _mm_set_ps(tw3.im, tw3.re, tw2.im, tw2.re) }; - let twiddle45 = unsafe { _mm_set_ps(tw5.im, tw5.re, tw4.im, tw4.re) }; - 
let twiddle01conj = unsafe { _mm_set_ps(-tw1.im, tw1.re, 0.0, 1.0) }; - let twiddle23conj = unsafe { _mm_set_ps(-tw3.im, tw3.re, -tw2.im, tw2.re) }; - let twiddle45conj = unsafe { _mm_set_ps(-tw5.im, tw5.re, -tw4.im, tw4.re) }; - let twiddle1 = unsafe { _mm_set_ps(tw1.im, tw1.re, tw1.im, tw1.re) }; - let twiddle2 = unsafe { _mm_set_ps(tw2.im, tw2.re, tw2.im, tw2.re) }; - let twiddle4 = unsafe { _mm_set_ps(tw4.im, tw4.re, tw4.im, tw4.re) }; - let twiddle5 = unsafe { _mm_set_ps(tw5.im, tw5.re, tw5.im, tw5.re) }; - let twiddle1c = unsafe { _mm_set_ps(-tw1.im, tw1.re, -tw1.im, tw1.re) }; - let twiddle2c = unsafe { _mm_set_ps(-tw2.im, tw2.re, -tw2.im, tw2.re) }; - let twiddle4c = unsafe { _mm_set_ps(-tw4.im, tw4.re, -tw4.im, tw4.re) }; - let twiddle5c = unsafe { _mm_set_ps(-tw5.im, tw5.re, -tw5.im, tw5.re) }; - Self { - direction, - bf6, - bf12, - rotate90, - twiddle01, - twiddle23, - twiddle45, - twiddle01conj, - twiddle23conj, - twiddle45conj, - twiddle1, - twiddle2, - twiddle4, - twiddle5, - twiddle1c, - twiddle2c, - twiddle4c, - twiddle5c, + let tw6: Complex = twiddles::compute_twiddle(6, 24, direction); + let tw8: Complex = twiddles::compute_twiddle(8, 24, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 24, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 24, direction); + let tw12: Complex = twiddles::compute_twiddle(12, 24, direction); + let tw15: Complex = twiddles::compute_twiddle(15, 24, direction); + unsafe { + Self { + bf4: SseF32Butterfly4::new(direction), + bf6: SseF32Butterfly6::new(direction), + twiddles_packed: [ + pack_32(tw0, tw1), + pack_32(tw0, tw2), + pack_32(tw0, tw3), + pack_32(tw2, tw3), + pack_32(tw4, tw6), + pack_32(tw6, tw9), + pack_32(tw4, tw5), + pack_32(tw8, tw10), + pack_32(tw12, tw15), + ], + twiddle1: pack_32(tw1, tw1), + twiddle2: pack_32(tw2, tw2), + twiddle4: pack_32(tw4, tw4), + twiddle5: pack_32(tw5, tw5), + twiddle8: pack_32(tw8, tw8), + twiddle10: pack_32(tw10, tw10), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - let input_packed = - read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22}); - - let out = self.perform_fft_direct(input_packed); - - write_complex_to_array_strided!(out, buffer, 2, {0,1,2,3,4,5,6,7,8,9,10,11}); - } - - #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46}); - - let values = - interleave_complex_f32!(input_packed, 12, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); - - let out = self.perform_parallel_fft_direct(values); - - let out_sorted = - separate_interleaved_complex_f32!(out, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22}); - - write_complex_to_array_strided!(out_sorted, buffer, 2, {0,1,2,3,4,5,6,7,8,9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 }); - } - - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [__m128; 12]) -> [__m128; 12] { - // we're going to hardcode a step of split radix - - // step 1: copy and reorder the input into the scratch - let in0002 = extract_lo_lo_f32(input[0], input[1]); - let in0406 = extract_lo_lo_f32(input[2], input[3]); - let in0810 = extract_lo_lo_f32(input[4], input[5]); - let in1214 = extract_lo_lo_f32(input[6], input[7]); - let in1618 = extract_lo_lo_f32(input[8], input[9]); - let in2022 = extract_lo_lo_f32(input[10], input[11]); - - let in0105 = extract_hi_hi_f32(input[0], 
input[2]); - let in0913 = extract_hi_hi_f32(input[4], input[6]); - let in1721 = extract_hi_hi_f32(input[8], input[10]); - - let in2303 = extract_hi_hi_f32(input[11], input[1]); - let in0711 = extract_hi_hi_f32(input[3], input[5]); - let in1519 = extract_hi_hi_f32(input[7], input[9]); - - let in_evens = [in0002, in0406, in0810, in1214, in1618, in2022]; - + // we're going to hardcode a step of 8x4 mixed radix + // step 1: transpose (skipped since the vectors mean our data is already in the correct format) + // and // step 2: column FFTs - let evens = self.bf12.perform_fft_direct(in_evens); - let mut odds1 = self.bf6.perform_fft_direct(in0105, in0913, in1721); - let mut odds3 = self.bf6.perform_fft_direct(in2303, in0711, in1519); - - // step 3: apply twiddle factors - odds1[0] = SseVector::mul_complex(odds1[0], self.twiddle01); - odds3[0] = SseVector::mul_complex(odds3[0], self.twiddle01conj); - - odds1[1] = SseVector::mul_complex(odds1[1], self.twiddle23); - odds3[1] = SseVector::mul_complex(odds3[1], self.twiddle23conj); - - odds1[2] = SseVector::mul_complex(odds1[2], self.twiddle45); - odds3[2] = SseVector::mul_complex(odds3[2], self.twiddle45conj); + // and + // step 3: twiddle factors + let load = |i| [ + buffer.load_complex(i), + buffer.load_complex(i + 6), + buffer.load_complex(i + 12), + buffer.load_complex(i + 18), + ]; - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); + let mut tmp1 = self.bf4.perform_parallel_fft_direct(load(2)); + tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddles_packed[3]); + tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddles_packed[4]); + tmp1[3] = SseVector::mul_complex(tmp1[3], self.twiddles_packed[5]); + let [mid2, mid3] = transpose_complex_2x2_f32(tmp1[0], tmp1[1]); + let [mid8, mid9] = transpose_complex_2x2_f32(tmp1[2], tmp1[3]); + + let mut tmp2 = self.bf4.perform_parallel_fft_direct(load(4)); + tmp2[1] = SseVector::mul_complex(tmp2[1], self.twiddles_packed[6]); + tmp2[2] = SseVector::mul_complex(tmp2[2], self.twiddles_packed[7]); + tmp2[3] = SseVector::mul_complex(tmp2[3], self.twiddles_packed[8]); + let [mid4, mid5] = transpose_complex_2x2_f32(tmp2[0], tmp2[1]); + let [mid10, mid11] = transpose_complex_2x2_f32(tmp2[2], tmp2[3]); + + let mut tmp0 = self.bf4.perform_parallel_fft_direct(load(0)); + tmp0[1] = SseVector::mul_complex(tmp0[1], self.twiddles_packed[0]); + tmp0[2] = SseVector::mul_complex(tmp0[2], self.twiddles_packed[1]); + tmp0[3] = SseVector::mul_complex(tmp0[3], self.twiddles_packed[2]); + let [mid0, mid1] = transpose_complex_2x2_f32(tmp0[0], tmp0[1]); + let [mid6, mid7] = transpose_complex_2x2_f32(tmp0[2], tmp0[3]); + + // step 4 and 5: transpose and cross FFTs + let mut store = |i, vectors: [__m128; 6]| { + buffer.store_complex(vectors[0], i); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + buffer.store_complex(vectors[4], i + 16); + buffer.store_complex(vectors[5], i + 20); + }; - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); + let out0 = self.bf6.perform_parallel_fft_direct(mid0, mid1, mid2, mid3, mid4, mid5); + store(0, out0); - //step 5: copy/add/subtract data back to buffer - [ - 
_mm_add_ps(evens[0], temp0[0]), - _mm_add_ps(evens[1], temp1[0]), - _mm_add_ps(evens[2], temp2[0]), - _mm_add_ps(evens[3], temp0[1]), - _mm_add_ps(evens[4], temp1[1]), - _mm_add_ps(evens[5], temp2[1]), - _mm_sub_ps(evens[0], temp0[0]), - _mm_sub_ps(evens[1], temp1[0]), - _mm_sub_ps(evens[2], temp2[0]), - _mm_sub_ps(evens[3], temp0[1]), - _mm_sub_ps(evens[4], temp1[1]), - _mm_sub_ps(evens[5], temp2[1]), - ] + let out1 = self.bf6.perform_parallel_fft_direct(mid6, mid7, mid8, mid9, mid10, mid11); + store(2, out1); } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_direct(&self, input: [__m128; 24]) -> [__m128; 24] { - // we're going to hardcode a step of split radix - - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf12.perform_parallel_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], - ]); - let mut odds1 = self.bf6.perform_parallel_fft_direct( - input[1], input[5], input[9], input[13], input[17], input[21], - ); - let mut odds3 = self.bf6.perform_parallel_fft_direct( - input[23], input[3], input[7], input[11], input[15], input[19], - ); - - // twiddle factor helpers - let rotate45 = |vec| { - let rotated = self.rotate90.rotate_both(vec); - let sum = _mm_add_ps(vec, rotated); - _mm_mul_ps(sum, _mm_set1_ps(0.5f32.sqrt())) + pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl SseArrayMut) { + let load = |i: usize| { + let [a0, a1] = transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 24)); + let [b0, b1] = transpose_complex_2x2_f32(buffer.load_complex(i + 6), buffer.load_complex(i + 30)); + let [c0, c1] = transpose_complex_2x2_f32(buffer.load_complex(i + 12), buffer.load_complex(i + 36)); + let [d0, d1] = transpose_complex_2x2_f32(buffer.load_complex(i + 18), buffer.load_complex(i + 42)); + [[a0, b0, c0, d0], [a1, b1, c1, d1]] }; - let rotate315 = |vec| { - let rotated = self.rotate90.rotate_both(vec); - let sum = _mm_sub_ps(vec, rotated); - _mm_mul_ps(sum, _mm_set1_ps(0.5f32.sqrt())) + + let [in0, in1] = load(0); + let tmp0 = self.bf4.perform_parallel_fft_direct(in0); + let mut tmp1 = self.bf4.perform_parallel_fft_direct(in1); + tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[3] = self.bf4.rotate.rotate_both_45(tmp1[3]); + + let [in2, in3] = load(2); + let mut tmp2 = self.bf4.perform_parallel_fft_direct(in2); + let mut tmp3 = self.bf4.perform_parallel_fft_direct(in3); + tmp2[1] = SseVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[2] = SseVector::mul_complex(tmp2[2], self.twiddle4); + tmp2[3] = self.bf4.rotate.rotate_both(tmp2[3]); + tmp3[1] = self.bf4.rotate.rotate_both_45(tmp3[1]); + tmp3[2] = self.bf4.rotate.rotate_both(tmp3[2]); + tmp3[3] = self.bf4.rotate.rotate_both_135(tmp3[3]); + + let [in4, in5] = load(4); + let mut tmp4 = self.bf4.perform_parallel_fft_direct(in4); + let mut tmp5 = self.bf4.perform_parallel_fft_direct(in5); + tmp4[1] = SseVector::mul_complex(tmp4[1], self.twiddle4); + tmp4[2] = SseVector::mul_complex(tmp4[2], self.twiddle8); + tmp4[3] = SseVector::neg(tmp4[3]); + tmp5[1] = SseVector::mul_complex(tmp5[1], self.twiddle5); + tmp5[2] = SseVector::mul_complex(tmp5[2], self.twiddle10); + tmp5[3] = self.bf4.rotate.rotate_both_225(tmp5[3]); + + // step 4 and 5: transpose and cross FFTs + let mut store = |i, vectors_a: [__m128; 6], vectors_b: [__m128; 6]| { + for n in 0..6 { + let 
[a, b] = transpose_complex_2x2_f32(vectors_a[n], vectors_b[n]); + buffer.store_complex(a, i + n*4); + buffer.store_complex(b, i + n*4 + 24); + } }; - // step 3: apply twiddle factors - odds1[1] = SseVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = SseVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = SseVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = SseVector::mul_complex(odds3[2], self.twiddle2c); - - odds1[3] = rotate45(odds1[3]); - odds3[3] = rotate315(odds3[3]); - - odds1[4] = SseVector::mul_complex(odds1[4], self.twiddle4); - odds3[4] = SseVector::mul_complex(odds3[4], self.twiddle4c); - - odds1[5] = SseVector::mul_complex(odds1[5], self.twiddle5); - odds3[5] = SseVector::mul_complex(odds3[5], self.twiddle5c); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); - let mut temp4 = parallel_fft2_interleaved_f32(odds1[4], odds3[4]); - let mut temp5 = parallel_fft2_interleaved_f32(odds1[5], odds3[5]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); - temp4[1] = self.rotate90.rotate_both(temp4[1]); - temp5[1] = self.rotate90.rotate_both(temp5[1]); - - //step 5: copy/add/subtract data back to buffer - [ - _mm_add_ps(evens[0], temp0[0]), - _mm_add_ps(evens[1], temp1[0]), - _mm_add_ps(evens[2], temp2[0]), - _mm_add_ps(evens[3], temp3[0]), - _mm_add_ps(evens[4], temp4[0]), - _mm_add_ps(evens[5], temp5[0]), - _mm_add_ps(evens[6], temp0[1]), - _mm_add_ps(evens[7], temp1[1]), - _mm_add_ps(evens[8], temp2[1]), - _mm_add_ps(evens[9], temp3[1]), - _mm_add_ps(evens[10], temp4[1]), - _mm_add_ps(evens[11], temp5[1]), - _mm_sub_ps(evens[0], temp0[0]), - _mm_sub_ps(evens[1], temp1[0]), - _mm_sub_ps(evens[2], temp2[0]), - _mm_sub_ps(evens[3], temp3[0]), - _mm_sub_ps(evens[4], temp4[0]), - _mm_sub_ps(evens[5], temp5[0]), - _mm_sub_ps(evens[6], temp0[1]), - _mm_sub_ps(evens[7], temp1[1]), - _mm_sub_ps(evens[8], temp2[1]), - _mm_sub_ps(evens[9], temp3[1]), - _mm_sub_ps(evens[10], temp4[1]), - _mm_sub_ps(evens[11], temp5[1]), - ] + let out0 = self.bf6.perform_parallel_fft_direct(tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]); + let out1 = self.bf6.perform_parallel_fft_direct(tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]); + store(0, out0, out1); + + let out2 = self.bf6.perform_parallel_fft_direct(tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2]); + let out3 = self.bf6.perform_parallel_fft_direct(tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3]); + store(2, out2, out3); } } @@ -2874,170 +2691,109 @@ impl SseF32Butterfly24 { // pub struct SseF64Butterfly24 { - direction: FftDirection, + bf4: SseF64Butterfly4, bf6: SseF64Butterfly6, - bf12: SseF64Butterfly12, - rotate90: Rotate90F64, twiddle1: __m128d, twiddle2: __m128d, twiddle4: __m128d, twiddle5: __m128d, - twiddle1c: __m128d, - twiddle2c: __m128d, - twiddle4c: __m128d, - twiddle5c: __m128d, + twiddle8: __m128d, + twiddle10: __m128d, } boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly24, 24, |this: &SseF64Butterfly24<_>| { - this.direction + this.bf4.direction }); boilerplate_fft_sse_common_butterfly!(SseF64Butterfly24, 24, |this: 
&SseF64Butterfly24<_>| this - .direction); + .bf4.direction); impl SseF64Butterfly24 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f64::(); - let bf6 = SseF64Butterfly6::new(direction); - let bf12 = SseF64Butterfly12::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F64::new(true) - } else { - Rotate90F64::new(false) - }; - let twiddle1 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(1, 24, direction).re as *const f64) }; - let twiddle2 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(2, 24, direction).re as *const f64) }; - let twiddle4 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(4, 24, direction).re as *const f64) }; - let twiddle5 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(5, 24, direction).re as *const f64) }; - let twiddle1c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(1, 24, direction).conj().re as *const f64) - }; - let twiddle2c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(2, 24, direction).conj().re as *const f64) - }; - let twiddle4c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(4, 24, direction).conj().re as *const f64) - }; - let twiddle5c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(5, 24, direction).conj().re as *const f64) - }; - Self { - direction, - bf6, - bf12, - rotate90, - twiddle1, - twiddle2, - twiddle4, - twiddle5, - twiddle1c, - twiddle2c, - twiddle4c, - twiddle5c, + let tw1: Complex = twiddles::compute_twiddle(1, 24, direction); + let tw2: Complex = twiddles::compute_twiddle(2, 24, direction); + let tw4: Complex = twiddles::compute_twiddle(4, 24, direction); + let tw5: Complex = twiddles::compute_twiddle(5, 24, direction); + let tw8: Complex = twiddles::compute_twiddle(8, 24, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 24, direction); + + unsafe { + Self { + bf4: SseF64Butterfly4::new(direction), + bf6: SseF64Butterfly6::new(direction), + twiddle1: pack_64(tw1), + twiddle2: pack_64(tw2), + twiddle4: pack_64(tw4), + twiddle5: pack_64(tw5), + twiddle8: pack_64(tw8), + twiddle10: pack_64(tw10), + } } } #[inline(always)] - pub(crate) unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - let values = read_complex_to_array!(buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); - - let out = self.perform_fft_direct(values); - - write_complex_to_array!(out, buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); - } - - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [__m128d; 24]) -> [__m128d; 24] { - // we're going to hardcode a step of split radix - - // step 1: copy and reorder the input into the scratch + unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { + // we're going to hardcode a step of 8x4 mixed radix + // step 1: transpose (skipped since the vectors mean our data is already in the correct format) // and // step 2: column FFTs - let evens = self.bf12.perform_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], - ]); - let mut odds1 = self.bf6.perform_fft_direct( - input[1], input[5], input[9], input[13], input[17], input[21], - ); - let mut odds3 = self.bf6.perform_fft_direct( - input[23], input[3], input[7], input[11], input[15], input[19], - ); + // and + // step 3: twiddle factors + let load = |i| [ + buffer.load_complex(i), + buffer.load_complex(i + 6), + buffer.load_complex(i + 12), + 
buffer.load_complex(i + 18), + ]; - // twiddle factor helpers - let rotate45 = |vec| { - let rotated = self.rotate90.rotate(vec); - let sum = _mm_add_pd(vec, rotated); - _mm_mul_pd(sum, _mm_set1_pd(0.5f64.sqrt())) - }; - let rotate315 = |vec| { - let rotated = self.rotate90.rotate(vec); - let sum = _mm_sub_pd(vec, rotated); - _mm_mul_pd(sum, _mm_set1_pd(0.5f64.sqrt())) + let mut tmp1 = self.bf4.perform_fft_direct(load(1)); + tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[3] = self.bf4.rotate.rotate_45(tmp1[3]); + + let mut tmp2 = self.bf4.perform_fft_direct(load(2)); + tmp2[1] = SseVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[2] = SseVector::mul_complex(tmp2[2], self.twiddle4); + tmp2[3] = self.bf4.rotate.rotate(tmp2[3]); + + let mut tmp4 = self.bf4.perform_fft_direct(load(4)); + tmp4[1] = SseVector::mul_complex(tmp4[1], self.twiddle4); + tmp4[2] = SseVector::mul_complex(tmp4[2], self.twiddle8); + tmp4[3] = SseVector::neg(tmp4[3]); + + let mut tmp5 = self.bf4.perform_fft_direct(load(5)); + tmp5[1] = SseVector::mul_complex(tmp5[1], self.twiddle5); + tmp5[2] = SseVector::mul_complex(tmp5[2], self.twiddle10); + tmp5[3] = self.bf4.rotate.rotate_225(tmp5[3]); + + let mut tmp3 = self.bf4.perform_fft_direct(load(3)); + tmp3[1] = self.bf4.rotate.rotate_45(tmp3[1]); + tmp3[2] = self.bf4.rotate.rotate(tmp3[2]); + tmp3[3] = self.bf4.rotate.rotate_135(tmp3[3]); + + let tmp0 = self.bf4.perform_fft_direct(load(0)); + + // step 4 and 5: transpose and cross FFTs + let mut store = |i, vectors: [__m128d; 6]| { + buffer.store_complex(vectors[0], i); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + buffer.store_complex(vectors[4], i + 16); + buffer.store_complex(vectors[5], i + 20); }; - // step 3: apply twiddle factors - odds1[1] = SseVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = SseVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = SseVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = SseVector::mul_complex(odds3[2], self.twiddle2c); - - odds1[3] = rotate45(odds1[3]); - odds3[3] = rotate315(odds3[3]); - - odds1[4] = SseVector::mul_complex(odds1[4], self.twiddle4); - odds3[4] = SseVector::mul_complex(odds3[4], self.twiddle4c); - - odds1[5] = SseVector::mul_complex(odds1[5], self.twiddle5); - odds3[5] = SseVector::mul_complex(odds3[5], self.twiddle5c); - - // step 4: cross FFTs - let mut temp0 = solo_fft2_f64(odds1[0], odds3[0]); - let mut temp1 = solo_fft2_f64(odds1[1], odds3[1]); - let mut temp2 = solo_fft2_f64(odds1[2], odds3[2]); - let mut temp3 = solo_fft2_f64(odds1[3], odds3[3]); - let mut temp4 = solo_fft2_f64(odds1[4], odds3[4]); - let mut temp5 = solo_fft2_f64(odds1[5], odds3[5]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate(temp0[1]); - temp1[1] = self.rotate90.rotate(temp1[1]); - temp2[1] = self.rotate90.rotate(temp2[1]); - temp3[1] = self.rotate90.rotate(temp3[1]); - temp4[1] = self.rotate90.rotate(temp4[1]); - temp5[1] = self.rotate90.rotate(temp5[1]); - - //step 5: copy/add/subtract data back to buffer - [ - _mm_add_pd(evens[0], temp0[0]), - _mm_add_pd(evens[1], temp1[0]), - _mm_add_pd(evens[2], temp2[0]), - _mm_add_pd(evens[3], temp3[0]), - _mm_add_pd(evens[4], temp4[0]), - _mm_add_pd(evens[5], temp5[0]), - _mm_add_pd(evens[6], temp0[1]), - _mm_add_pd(evens[7], temp1[1]), - _mm_add_pd(evens[8], temp2[1]), - _mm_add_pd(evens[9], temp3[1]), - 
_mm_add_pd(evens[10], temp4[1]), - _mm_add_pd(evens[11], temp5[1]), - _mm_sub_pd(evens[0], temp0[0]), - _mm_sub_pd(evens[1], temp1[0]), - _mm_sub_pd(evens[2], temp2[0]), - _mm_sub_pd(evens[3], temp3[0]), - _mm_sub_pd(evens[4], temp4[0]), - _mm_sub_pd(evens[5], temp5[0]), - _mm_sub_pd(evens[6], temp0[1]), - _mm_sub_pd(evens[7], temp1[1]), - _mm_sub_pd(evens[8], temp2[1]), - _mm_sub_pd(evens[9], temp3[1]), - _mm_sub_pd(evens[10], temp4[1]), - _mm_sub_pd(evens[11], temp5[1]), - ] + let out0 = self.bf6.perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]]); + store(0, out0); + + let out1 = self.bf6.perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]]); + store(1, out1); + + let out2 = self.bf6.perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2]]); + store(2, out2); + + let out3 = self.bf6.perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3]]); + store(3, out3); } } @@ -3049,49 +2805,31 @@ impl SseF64Butterfly24 { // pub struct SseF32Butterfly32 { - direction: FftDirection, bf8: SseF32Butterfly8, - bf16: SseF32Butterfly16, - rotate90: Rotate90F32, - twiddle01: __m128, - twiddle23: __m128, - twiddle45: __m128, - twiddle67: __m128, - twiddle01conj: __m128, - twiddle23conj: __m128, - twiddle45conj: __m128, - twiddle67conj: __m128, + twiddles_packed: [__m128; 12], twiddle1: __m128, twiddle2: __m128, twiddle3: __m128, - twiddle4: __m128, twiddle5: __m128, twiddle6: __m128, twiddle7: __m128, - twiddle1c: __m128, - twiddle2c: __m128, - twiddle3c: __m128, - twiddle4c: __m128, - twiddle5c: __m128, - twiddle6c: __m128, - twiddle7c: __m128, + twiddle9: __m128, + twiddle10: __m128, + twiddle14: __m128, + twiddle15: __m128, + twiddle18: __m128, + twiddle21: __m128, } boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly32, 32, |this: &SseF32Butterfly32<_>| this - .direction); + .bf8.bf4.direction); boilerplate_fft_sse_common_butterfly!(SseF32Butterfly32, 32, |this: &SseF32Butterfly32<_>| this - .direction); + .bf8.bf4.direction); impl SseF32Butterfly32 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f32::(); - let bf8 = SseF32Butterfly8::new(direction); - let bf16 = SseF32Butterfly16::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F32::new(true) - } else { - Rotate90F32::new(false) - }; + let tw0: Complex = Complex { re: 1.0, im: 0.0 }; let tw1: Complex = twiddles::compute_twiddle(1, 32, direction); let tw2: Complex = twiddles::compute_twiddle(2, 32, direction); let tw3: Complex = twiddles::compute_twiddle(3, 32, direction); @@ -3099,258 +2837,172 @@ impl SseF32Butterfly32 { let tw5: Complex = twiddles::compute_twiddle(5, 32, direction); let tw6: Complex = twiddles::compute_twiddle(6, 32, direction); let tw7: Complex = twiddles::compute_twiddle(7, 32, direction); - let twiddle01 = unsafe { _mm_set_ps(tw1.im, tw1.re, 0.0, 1.0) }; - let twiddle23 = unsafe { _mm_set_ps(tw3.im, tw3.re, tw2.im, tw2.re) }; - let twiddle45 = unsafe { _mm_set_ps(tw5.im, tw5.re, tw4.im, tw4.re) }; - let twiddle67 = unsafe { _mm_set_ps(tw7.im, tw7.re, tw6.im, tw6.re) }; - let twiddle01conj = unsafe { _mm_set_ps(-tw1.im, tw1.re, 0.0, 1.0) }; - let twiddle23conj = unsafe { _mm_set_ps(-tw3.im, tw3.re, -tw2.im, tw2.re) }; - let twiddle45conj = unsafe { _mm_set_ps(-tw5.im, tw5.re, -tw4.im, tw4.re) }; - let twiddle67conj = unsafe { _mm_set_ps(-tw7.im, tw7.re, -tw6.im, tw6.re) }; - let twiddle1 = unsafe { _mm_set_ps(tw1.im, tw1.re, tw1.im, tw1.re) }; - let twiddle2 = unsafe { _mm_set_ps(tw2.im, tw2.re, 
tw2.im, tw2.re) }; - let twiddle3 = unsafe { _mm_set_ps(tw3.im, tw3.re, tw3.im, tw3.re) }; - let twiddle4 = unsafe { _mm_set_ps(tw4.im, tw4.re, tw4.im, tw4.re) }; - let twiddle5 = unsafe { _mm_set_ps(tw5.im, tw5.re, tw5.im, tw5.re) }; - let twiddle6 = unsafe { _mm_set_ps(tw6.im, tw6.re, tw6.im, tw6.re) }; - let twiddle7 = unsafe { _mm_set_ps(tw7.im, tw7.re, tw7.im, tw7.re) }; - let twiddle1c = unsafe { _mm_set_ps(-tw1.im, tw1.re, -tw1.im, tw1.re) }; - let twiddle2c = unsafe { _mm_set_ps(-tw2.im, tw2.re, -tw2.im, tw2.re) }; - let twiddle3c = unsafe { _mm_set_ps(-tw3.im, tw3.re, -tw3.im, tw3.re) }; - let twiddle4c = unsafe { _mm_set_ps(-tw4.im, tw4.re, -tw4.im, tw4.re) }; - let twiddle5c = unsafe { _mm_set_ps(-tw5.im, tw5.re, -tw5.im, tw5.re) }; - let twiddle6c = unsafe { _mm_set_ps(-tw6.im, tw6.re, -tw6.im, tw6.re) }; - let twiddle7c = unsafe { _mm_set_ps(-tw7.im, tw7.re, -tw7.im, tw7.re) }; - Self { - direction, - bf8, - bf16, - rotate90, - twiddle01, - twiddle23, - twiddle45, - twiddle67, - twiddle01conj, - twiddle23conj, - twiddle45conj, - twiddle67conj, - twiddle1, - twiddle2, - twiddle3, - twiddle4, - twiddle5, - twiddle6, - twiddle7, - twiddle1c, - twiddle2c, - twiddle3c, - twiddle4c, - twiddle5c, - twiddle6c, - twiddle7c, + let tw8: Complex = twiddles::compute_twiddle(8, 32, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 32, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 32, direction); + let tw12: Complex = twiddles::compute_twiddle(12, 32, direction); + let tw14: Complex = twiddles::compute_twiddle(14, 32, direction); + let tw15: Complex = twiddles::compute_twiddle(15, 32, direction); + let tw18: Complex = twiddles::compute_twiddle(18, 32, direction); + let tw21: Complex = twiddles::compute_twiddle(21, 32, direction); + unsafe { + Self { + bf8: SseF32Butterfly8::new(direction), + twiddles_packed: [ + pack_32(tw0, tw1), + pack_32(tw0, tw2), + pack_32(tw0, tw3), + pack_32(tw2, tw3), + pack_32(tw4, tw6), + pack_32(tw6, tw9), + pack_32(tw4, tw5), + pack_32(tw8, tw10), + pack_32(tw12, tw15), + pack_32(tw6, tw7), + pack_32(tw12, tw14), + pack_32(tw18, tw21), + ], + twiddle1: pack_32(tw1, tw1), + twiddle2: pack_32(tw2, tw2), + twiddle3: pack_32(tw3, tw3), + twiddle5: pack_32(tw5, tw5), + twiddle6: pack_32(tw6, tw6), + twiddle7: pack_32(tw7, tw7), + twiddle9: pack_32(tw9, tw9), + twiddle10: pack_32(tw10, tw10), + twiddle14: pack_32(tw14, tw14), + twiddle15: pack_32(tw15, tw15), + twiddle18: pack_32(tw18, tw18), + twiddle21: pack_32(tw21, tw21), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 }); - - let out = self.perform_fft_direct(input_packed); - - write_complex_to_array_strided!(out, buffer, 2, {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); - } - - #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}); - - let values = interleave_complex_f32!(input_packed, 16, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - - let out = self.perform_parallel_fft_direct(values); - - let out_sorted = separate_interleaved_complex_f32!(out, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}); - - write_complex_to_array_strided!(out_sorted, buffer, 2, {0,1,2,3,4,5,6,7,8,9, 
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 }); - } - - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [__m128; 16]) -> [__m128; 16] { - // we're going to hardcode a step of split radix - - // step 1: copy and reorder the input into the scratch - let in0002 = extract_lo_lo_f32(input[0], input[1]); - let in0406 = extract_lo_lo_f32(input[2], input[3]); - let in0810 = extract_lo_lo_f32(input[4], input[5]); - let in1214 = extract_lo_lo_f32(input[6], input[7]); - let in1618 = extract_lo_lo_f32(input[8], input[9]); - let in2022 = extract_lo_lo_f32(input[10], input[11]); - let in2426 = extract_lo_lo_f32(input[12], input[13]); - let in2830 = extract_lo_lo_f32(input[14], input[15]); - - let in0105 = extract_hi_hi_f32(input[0], input[2]); - let in0913 = extract_hi_hi_f32(input[4], input[6]); - let in1721 = extract_hi_hi_f32(input[8], input[10]); - let in2529 = extract_hi_hi_f32(input[12], input[14]); - - let in3103 = extract_hi_hi_f32(input[15], input[1]); - let in0711 = extract_hi_hi_f32(input[3], input[5]); - let in1519 = extract_hi_hi_f32(input[7], input[9]); - let in2327 = extract_hi_hi_f32(input[11], input[13]); - - let in_evens = [ - in0002, in0406, in0810, in1214, in1618, in2022, in2426, in2830, - ]; - + // we're going to hardcode a step of 8x4 mixed radix + // step 1: transpose (skipped since the vectors mean our data is already in the correct format) + // and // step 2: column FFTs - let evens = self.bf16.perform_fft_direct(in_evens); - let mut odds1 = self - .bf8 - .perform_fft_direct([in0105, in0913, in1721, in2529]); - let mut odds3 = self - .bf8 - .perform_fft_direct([in3103, in0711, in1519, in2327]); - - // step 3: apply twiddle factors - odds1[0] = SseVector::mul_complex(odds1[0], self.twiddle01); - odds3[0] = SseVector::mul_complex(odds3[0], self.twiddle01conj); - - odds1[1] = SseVector::mul_complex(odds1[1], self.twiddle23); - odds3[1] = SseVector::mul_complex(odds3[1], self.twiddle23conj); - - odds1[2] = SseVector::mul_complex(odds1[2], self.twiddle45); - odds3[2] = SseVector::mul_complex(odds3[2], self.twiddle45conj); - - odds1[3] = SseVector::mul_complex(odds1[3], self.twiddle67); - odds3[3] = SseVector::mul_complex(odds3[3], self.twiddle67conj); + // and + // step 3: twiddle factors + let load = |i| [ + buffer.load_complex(i), + buffer.load_complex(i + 8), + buffer.load_complex(i + 16), + buffer.load_complex(i + 24), + ]; - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); + let mut tmp0 = self.bf8.bf4.perform_parallel_fft_direct(load(0)); + tmp0[1] = SseVector::mul_complex(tmp0[1], self.twiddles_packed[0]); + tmp0[2] = SseVector::mul_complex(tmp0[2], self.twiddles_packed[1]); + tmp0[3] = SseVector::mul_complex(tmp0[3], self.twiddles_packed[2]); + let [mid0, mid1] = transpose_complex_2x2_f32(tmp0[0], tmp0[1]); + let [mid8, mid9] = transpose_complex_2x2_f32(tmp0[2], tmp0[3]); + + let mut tmp1 = self.bf8.bf4.perform_parallel_fft_direct(load(2)); + tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddles_packed[3]); + tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddles_packed[4]); + tmp1[3] = SseVector::mul_complex(tmp1[3], self.twiddles_packed[5]); + let [mid2, mid3] = transpose_complex_2x2_f32(tmp1[0], tmp1[1]); + let [mid10, mid11] = transpose_complex_2x2_f32(tmp1[2], 
tmp1[3]); + + let mut tmp2 = self.bf8.bf4.perform_parallel_fft_direct(load(4)); + tmp2[1] = SseVector::mul_complex(tmp2[1], self.twiddles_packed[6]); + tmp2[2] = SseVector::mul_complex(tmp2[2], self.twiddles_packed[7]); + tmp2[3] = SseVector::mul_complex(tmp2[3], self.twiddles_packed[8]); + let [mid4, mid5] = transpose_complex_2x2_f32(tmp2[0], tmp2[1]); + let [mid12, mid13] = transpose_complex_2x2_f32(tmp2[2], tmp2[3]); + + let mut tmp3 = self.bf8.bf4.perform_parallel_fft_direct(load(6)); + tmp3[1] = SseVector::mul_complex(tmp3[1], self.twiddles_packed[9]); + tmp3[2] = SseVector::mul_complex(tmp3[2], self.twiddles_packed[10]); + tmp3[3] = SseVector::mul_complex(tmp3[3], self.twiddles_packed[11]); + let [mid6, mid7] = transpose_complex_2x2_f32(tmp3[0], tmp3[1]); + let [mid14, mid15] = transpose_complex_2x2_f32(tmp3[2], tmp3[3]); + + // step 4 and 5: transpose and cross FFTs + let mut store = |i, vectors: [__m128; 8]| { + buffer.store_complex(vectors[0], i); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + buffer.store_complex(vectors[4], i + 16); + buffer.store_complex(vectors[5], i + 20); + buffer.store_complex(vectors[6], i + 24); + buffer.store_complex(vectors[7], i + 28); + }; - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); + let out0 = self.bf8.perform_parallel_fft_direct([mid0, mid1, mid2, mid3, mid4, mid5, mid6, mid7]); + store(0, out0); - //step 5: copy/add/subtract data back to buffer - [ - _mm_add_ps(evens[0], temp0[0]), - _mm_add_ps(evens[1], temp1[0]), - _mm_add_ps(evens[2], temp2[0]), - _mm_add_ps(evens[3], temp3[0]), - _mm_add_ps(evens[4], temp0[1]), - _mm_add_ps(evens[5], temp1[1]), - _mm_add_ps(evens[6], temp2[1]), - _mm_add_ps(evens[7], temp3[1]), - _mm_sub_ps(evens[0], temp0[0]), - _mm_sub_ps(evens[1], temp1[0]), - _mm_sub_ps(evens[2], temp2[0]), - _mm_sub_ps(evens[3], temp3[0]), - _mm_sub_ps(evens[4], temp0[1]), - _mm_sub_ps(evens[5], temp1[1]), - _mm_sub_ps(evens[6], temp2[1]), - _mm_sub_ps(evens[7], temp3[1]), - ] + let out1 = self.bf8.perform_parallel_fft_direct([mid8, mid9, mid10, mid11, mid12, mid13, mid14, mid15]); + store(2, out1); } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_direct(&self, input: [__m128; 32]) -> [__m128; 32] { - // we're going to hardcode a step of split radix + pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl SseArrayMut) { + let load = |i: usize| { + let [a0, a1] = transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 32)); + let [b0, b1] = transpose_complex_2x2_f32(buffer.load_complex(i + 8), buffer.load_complex(i + 40)); + let [c0, c1] = transpose_complex_2x2_f32(buffer.load_complex(i + 16), buffer.load_complex(i + 48)); + let [d0, d1] = transpose_complex_2x2_f32(buffer.load_complex(i + 24), buffer.load_complex(i + 56)); + [[a0, b0, c0, d0], [a1, b1, c1, d1]] + }; - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf16.perform_parallel_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], input[24], input[26], input[28], input[30], - ]); - let mut odds1 = self.bf8.perform_parallel_fft_direct([ - input[1], input[5], input[9], 
input[13], input[17], input[21], input[25], input[29], - ]); - let mut odds3 = self.bf8.perform_parallel_fft_direct([ - input[31], input[3], input[7], input[11], input[15], input[19], input[23], input[27], - ]); + let [in0, in1] = load(0); + let tmp0 = self.bf8.bf4.perform_parallel_fft_direct(in0); + let mut tmp1 = self.bf8.bf4.perform_parallel_fft_direct(in1); + tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[3] = SseVector::mul_complex(tmp1[3], self.twiddle3); + + let [in2, in3] = load(2); + let mut tmp2 = self.bf8.bf4.perform_parallel_fft_direct(in2); + let mut tmp3 = self.bf8.bf4.perform_parallel_fft_direct(in3); + tmp2[1] = SseVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[2] = self.bf8.bf4.rotate.rotate_both_45(tmp2[2]); + tmp2[3] = SseVector::mul_complex(tmp2[3], self.twiddle6); + tmp3[1] = SseVector::mul_complex(tmp3[1], self.twiddle3); + tmp3[2] = SseVector::mul_complex(tmp3[2], self.twiddle6); + tmp3[3] = SseVector::mul_complex(tmp3[3], self.twiddle9); + + let [in4, in5] = load(4); + let mut tmp4 = self.bf8.bf4.perform_parallel_fft_direct(in4); + let mut tmp5 = self.bf8.bf4.perform_parallel_fft_direct(in5); + tmp4[1] = self.bf8.bf4.rotate.rotate_both_45(tmp4[1]); + tmp4[2] = self.bf8.bf4.rotate.rotate_both(tmp4[2]); + tmp4[3] = self.bf8.bf4.rotate.rotate_both_135(tmp4[3]); + tmp5[1] = SseVector::mul_complex(tmp5[1], self.twiddle5); + tmp5[2] = SseVector::mul_complex(tmp5[2], self.twiddle10); + tmp5[3] = SseVector::mul_complex(tmp5[3], self.twiddle15); + + let [in6, in7] = load(6); + let mut tmp6 = self.bf8.bf4.perform_parallel_fft_direct(in6); + let mut tmp7 = self.bf8.bf4.perform_parallel_fft_direct(in7); + tmp6[1] = SseVector::mul_complex(tmp6[1], self.twiddle6); + tmp6[2] = self.bf8.bf4.rotate.rotate_both_135(tmp6[2]); + tmp6[3] = SseVector::mul_complex(tmp6[3], self.twiddle18); + tmp7[1] = SseVector::mul_complex(tmp7[1], self.twiddle7); + tmp7[2] = SseVector::mul_complex(tmp7[2], self.twiddle14); + tmp7[3] = SseVector::mul_complex(tmp7[3], self.twiddle21); + + // step 4 and 5: transpose and cross FFTs + let mut store = |i, vectors_a: [__m128; 8], vectors_b: [__m128; 8]| { + for n in 0..8 { + let [a, b] = transpose_complex_2x2_f32(vectors_a[n], vectors_b[n]); + buffer.store_complex(a, i + n*4); + buffer.store_complex(b, i + n*4 + 32); + } + }; - // step 3: apply twiddle factors - odds1[1] = SseVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = SseVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = SseVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = SseVector::mul_complex(odds3[2], self.twiddle2c); - - odds1[3] = SseVector::mul_complex(odds1[3], self.twiddle3); - odds3[3] = SseVector::mul_complex(odds3[3], self.twiddle3c); - - odds1[4] = SseVector::mul_complex(odds1[4], self.twiddle4); - odds3[4] = SseVector::mul_complex(odds3[4], self.twiddle4c); - - odds1[5] = SseVector::mul_complex(odds1[5], self.twiddle5); - odds3[5] = SseVector::mul_complex(odds3[5], self.twiddle5c); - - odds1[6] = SseVector::mul_complex(odds1[6], self.twiddle6); - odds3[6] = SseVector::mul_complex(odds3[6], self.twiddle6c); - - odds1[7] = SseVector::mul_complex(odds1[7], self.twiddle7); - odds3[7] = SseVector::mul_complex(odds3[7], self.twiddle7c); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = 
parallel_fft2_interleaved_f32(odds1[3], odds3[3]); - let mut temp4 = parallel_fft2_interleaved_f32(odds1[4], odds3[4]); - let mut temp5 = parallel_fft2_interleaved_f32(odds1[5], odds3[5]); - let mut temp6 = parallel_fft2_interleaved_f32(odds1[6], odds3[6]); - let mut temp7 = parallel_fft2_interleaved_f32(odds1[7], odds3[7]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); - temp4[1] = self.rotate90.rotate_both(temp4[1]); - temp5[1] = self.rotate90.rotate_both(temp5[1]); - temp6[1] = self.rotate90.rotate_both(temp6[1]); - temp7[1] = self.rotate90.rotate_both(temp7[1]); - - //step 5: copy/add/subtract data back to buffer - [ - _mm_add_ps(evens[0], temp0[0]), - _mm_add_ps(evens[1], temp1[0]), - _mm_add_ps(evens[2], temp2[0]), - _mm_add_ps(evens[3], temp3[0]), - _mm_add_ps(evens[4], temp4[0]), - _mm_add_ps(evens[5], temp5[0]), - _mm_add_ps(evens[6], temp6[0]), - _mm_add_ps(evens[7], temp7[0]), - _mm_add_ps(evens[8], temp0[1]), - _mm_add_ps(evens[9], temp1[1]), - _mm_add_ps(evens[10], temp2[1]), - _mm_add_ps(evens[11], temp3[1]), - _mm_add_ps(evens[12], temp4[1]), - _mm_add_ps(evens[13], temp5[1]), - _mm_add_ps(evens[14], temp6[1]), - _mm_add_ps(evens[15], temp7[1]), - _mm_sub_ps(evens[0], temp0[0]), - _mm_sub_ps(evens[1], temp1[0]), - _mm_sub_ps(evens[2], temp2[0]), - _mm_sub_ps(evens[3], temp3[0]), - _mm_sub_ps(evens[4], temp4[0]), - _mm_sub_ps(evens[5], temp5[0]), - _mm_sub_ps(evens[6], temp6[0]), - _mm_sub_ps(evens[7], temp7[0]), - _mm_sub_ps(evens[8], temp0[1]), - _mm_sub_ps(evens[9], temp1[1]), - _mm_sub_ps(evens[10], temp2[1]), - _mm_sub_ps(evens[11], temp3[1]), - _mm_sub_ps(evens[12], temp4[1]), - _mm_sub_ps(evens[13], temp5[1]), - _mm_sub_ps(evens[14], temp6[1]), - _mm_sub_ps(evens[15], temp7[1]), - ] + let out0 = self.bf8.perform_parallel_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0]]); + let out1 = self.bf8.perform_parallel_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1], tmp6[1], tmp7[1]]); + store(0, out0, out1); + + let out2 = self.bf8.perform_parallel_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2], tmp6[2], tmp7[2]]); + let out3 = self.bf8.perform_parallel_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3], tmp6[3], tmp7[3]]); + store(2, out2, out3); } } @@ -3362,211 +3014,143 @@ impl SseF32Butterfly32 { // pub struct SseF64Butterfly32 { - direction: FftDirection, bf8: SseF64Butterfly8, - bf16: SseF64Butterfly16, - rotate90: Rotate90F64, twiddle1: __m128d, twiddle2: __m128d, twiddle3: __m128d, - twiddle4: __m128d, twiddle5: __m128d, twiddle6: __m128d, twiddle7: __m128d, - twiddle1c: __m128d, - twiddle2c: __m128d, - twiddle3c: __m128d, - twiddle4c: __m128d, - twiddle5c: __m128d, - twiddle6c: __m128d, - twiddle7c: __m128d, + twiddle9: __m128d, + twiddle10: __m128d, + twiddle14: __m128d, + twiddle15: __m128d, + twiddle18: __m128d, + twiddle21: __m128d, } boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly32, 32, |this: &SseF64Butterfly32<_>| this - .direction); + .bf8.bf4.direction); boilerplate_fft_sse_common_butterfly!(SseF64Butterfly32, 32, |this: &SseF64Butterfly32<_>| this - .direction); + .bf8.bf4.direction); impl SseF64Butterfly32 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f64::(); - let bf8 = SseF64Butterfly8::new(direction); - let bf16 = 
SseF64Butterfly16::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F64::new(true) - } else { - Rotate90F64::new(false) - }; - let twiddle1 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(1, 32, direction).re as *const f64) }; - let twiddle2 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(2, 32, direction).re as *const f64) }; - let twiddle3 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(3, 32, direction).re as *const f64) }; - let twiddle4 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(4, 32, direction).re as *const f64) }; - let twiddle5 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(5, 32, direction).re as *const f64) }; - let twiddle6 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(6, 32, direction).re as *const f64) }; - let twiddle7 = - unsafe { _mm_loadu_pd(&twiddles::compute_twiddle(7, 32, direction).re as *const f64) }; - let twiddle1c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(1, 32, direction).conj().re as *const f64) - }; - let twiddle2c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(2, 32, direction).conj().re as *const f64) - }; - let twiddle3c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(3, 32, direction).conj().re as *const f64) - }; - let twiddle4c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(4, 32, direction).conj().re as *const f64) - }; - let twiddle5c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(5, 32, direction).conj().re as *const f64) - }; - let twiddle6c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(6, 32, direction).conj().re as *const f64) - }; - let twiddle7c = unsafe { - _mm_loadu_pd(&twiddles::compute_twiddle(7, 32, direction).conj().re as *const f64) - }; - - Self { - direction, - bf8, - bf16, - rotate90, - twiddle1, - twiddle2, - twiddle3, - twiddle4, - twiddle5, - twiddle6, - twiddle7, - twiddle1c, - twiddle2c, - twiddle3c, - twiddle4c, - twiddle5c, - twiddle6c, - twiddle7c, + let tw1: Complex = twiddles::compute_twiddle(1, 32, direction); + let tw2: Complex = twiddles::compute_twiddle(2, 32, direction); + let tw3: Complex = twiddles::compute_twiddle(3, 32, direction); + let tw5: Complex = twiddles::compute_twiddle(5, 32, direction); + let tw6: Complex = twiddles::compute_twiddle(6, 32, direction); + let tw7: Complex = twiddles::compute_twiddle(7, 32, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 32, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 32, direction); + let tw14: Complex = twiddles::compute_twiddle(14, 32, direction); + let tw15: Complex = twiddles::compute_twiddle(15, 32, direction); + let tw18: Complex = twiddles::compute_twiddle(18, 32, direction); + let tw21: Complex = twiddles::compute_twiddle(21, 32, direction); + + unsafe { + Self { + bf8: SseF64Butterfly8::new(direction), + twiddle1: pack_64(tw1), + twiddle2: pack_64(tw2), + twiddle3: pack_64(tw3), + twiddle5: pack_64(tw5), + twiddle6: pack_64(tw6), + twiddle7: pack_64(tw7), + twiddle9: pack_64(tw9), + twiddle10: pack_64(tw10), + twiddle14: pack_64(tw14), + twiddle15: pack_64(tw15), + twiddle18: pack_64(tw18), + twiddle21: pack_64(tw21), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - let values = read_complex_to_array!(buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); - - let out = self.perform_fft_direct(values); + // we're going to hardcode a step of 8x4 mixed radix + // step 1: transpose (skipped since the 
vectors mean our data is already in the correct format) + // and + // step 2: column FFTs + // and + // step 3: twiddle factors + let load = |i| [ + buffer.load_complex(i), + buffer.load_complex(i + 8), + buffer.load_complex(i + 16), + buffer.load_complex(i + 24), + ]; - write_complex_to_array!(out, buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); - } + let mut tmp1 = self.bf8.bf4.perform_fft_direct(load(1)); + tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[3] = SseVector::mul_complex(tmp1[3], self.twiddle3); + + let mut tmp2 = self.bf8.bf4.perform_fft_direct(load(2)); + tmp2[1] = SseVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[2] = self.bf8.bf4.rotate.rotate_45(tmp2[2]); + tmp2[3] = SseVector::mul_complex(tmp2[3], self.twiddle6); + + let mut tmp3 = self.bf8.bf4.perform_fft_direct(load(3)); + tmp3[1] = SseVector::mul_complex(tmp3[1], self.twiddle3); + tmp3[2] = SseVector::mul_complex(tmp3[2], self.twiddle6); + tmp3[3] = SseVector::mul_complex(tmp3[3], self.twiddle9); + + let mut tmp5 = self.bf8.bf4.perform_fft_direct(load(5)); + tmp5[1] = SseVector::mul_complex(tmp5[1], self.twiddle5); + tmp5[2] = SseVector::mul_complex(tmp5[2], self.twiddle10); + tmp5[3] = SseVector::mul_complex(tmp5[3], self.twiddle15); + + let mut tmp6 = self.bf8.bf4.perform_fft_direct(load(6)); + tmp6[1] = SseVector::mul_complex(tmp6[1], self.twiddle6); + tmp6[2] = self.bf8.bf4.rotate.rotate_135(tmp6[2]); + tmp6[3] = SseVector::mul_complex(tmp6[3], self.twiddle18); + + let mut tmp7 = self.bf8.bf4.perform_fft_direct(load(7)); + tmp7[1] = SseVector::mul_complex(tmp7[1], self.twiddle7); + tmp7[2] = SseVector::mul_complex(tmp7[2], self.twiddle14); + tmp7[3] = SseVector::mul_complex(tmp7[3], self.twiddle21); + + let mut tmp4 = self.bf8.bf4.perform_fft_direct(load(4)); + tmp4[1] = self.bf8.bf4.rotate.rotate_45(tmp4[1]); + tmp4[2] = self.bf8.bf4.rotate.rotate(tmp4[2]); + tmp4[3] = self.bf8.bf4.rotate.rotate_135(tmp4[3]); + + let tmp0 = self.bf8.bf4.perform_fft_direct(load(0)); + + // step 4 and 5: transpose and cross FFTs + let mut store = |i, vectors: [__m128d; 8]| { + buffer.store_complex(vectors[0], i); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + buffer.store_complex(vectors[4], i + 16); + buffer.store_complex(vectors[5], i + 20); + buffer.store_complex(vectors[6], i + 24); + buffer.store_complex(vectors[7], i + 28); + }; - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [__m128d; 32]) -> [__m128d; 32] { - // we're going to hardcode a step of split radix + let out0 = self.bf8.perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0]]); + store(0, out0); - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf16.perform_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], input[24], input[26], input[28], input[30], - ]); - let mut odds1 = self.bf8.perform_fft_direct([ - input[1], input[5], input[9], input[13], input[17], input[21], input[25], input[29], - ]); - let mut odds3 = self.bf8.perform_fft_direct([ - input[31], input[3], input[7], input[11], input[15], input[19], input[23], input[27], - ]); + let out1 = self.bf8.perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], 
tmp5[1], tmp6[1], tmp7[1]]); + store(1, out1); - // step 3: apply twiddle factors - odds1[1] = SseVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = SseVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = SseVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = SseVector::mul_complex(odds3[2], self.twiddle2c); - - odds1[3] = SseVector::mul_complex(odds1[3], self.twiddle3); - odds3[3] = SseVector::mul_complex(odds3[3], self.twiddle3c); - - odds1[4] = SseVector::mul_complex(odds1[4], self.twiddle4); - odds3[4] = SseVector::mul_complex(odds3[4], self.twiddle4c); - - odds1[5] = SseVector::mul_complex(odds1[5], self.twiddle5); - odds3[5] = SseVector::mul_complex(odds3[5], self.twiddle5c); - - odds1[6] = SseVector::mul_complex(odds1[6], self.twiddle6); - odds3[6] = SseVector::mul_complex(odds3[6], self.twiddle6c); - - odds1[7] = SseVector::mul_complex(odds1[7], self.twiddle7); - odds3[7] = SseVector::mul_complex(odds3[7], self.twiddle7c); - - // step 4: cross FFTs - let mut temp0 = solo_fft2_f64(odds1[0], odds3[0]); - let mut temp1 = solo_fft2_f64(odds1[1], odds3[1]); - let mut temp2 = solo_fft2_f64(odds1[2], odds3[2]); - let mut temp3 = solo_fft2_f64(odds1[3], odds3[3]); - let mut temp4 = solo_fft2_f64(odds1[4], odds3[4]); - let mut temp5 = solo_fft2_f64(odds1[5], odds3[5]); - let mut temp6 = solo_fft2_f64(odds1[6], odds3[6]); - let mut temp7 = solo_fft2_f64(odds1[7], odds3[7]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate(temp0[1]); - temp1[1] = self.rotate90.rotate(temp1[1]); - temp2[1] = self.rotate90.rotate(temp2[1]); - temp3[1] = self.rotate90.rotate(temp3[1]); - temp4[1] = self.rotate90.rotate(temp4[1]); - temp5[1] = self.rotate90.rotate(temp5[1]); - temp6[1] = self.rotate90.rotate(temp6[1]); - temp7[1] = self.rotate90.rotate(temp7[1]); - - //step 5: copy/add/subtract data back to buffer - [ - _mm_add_pd(evens[0], temp0[0]), - _mm_add_pd(evens[1], temp1[0]), - _mm_add_pd(evens[2], temp2[0]), - _mm_add_pd(evens[3], temp3[0]), - _mm_add_pd(evens[4], temp4[0]), - _mm_add_pd(evens[5], temp5[0]), - _mm_add_pd(evens[6], temp6[0]), - _mm_add_pd(evens[7], temp7[0]), - _mm_add_pd(evens[8], temp0[1]), - _mm_add_pd(evens[9], temp1[1]), - _mm_add_pd(evens[10], temp2[1]), - _mm_add_pd(evens[11], temp3[1]), - _mm_add_pd(evens[12], temp4[1]), - _mm_add_pd(evens[13], temp5[1]), - _mm_add_pd(evens[14], temp6[1]), - _mm_add_pd(evens[15], temp7[1]), - _mm_sub_pd(evens[0], temp0[0]), - _mm_sub_pd(evens[1], temp1[0]), - _mm_sub_pd(evens[2], temp2[0]), - _mm_sub_pd(evens[3], temp3[0]), - _mm_sub_pd(evens[4], temp4[0]), - _mm_sub_pd(evens[5], temp5[0]), - _mm_sub_pd(evens[6], temp6[0]), - _mm_sub_pd(evens[7], temp7[0]), - _mm_sub_pd(evens[8], temp0[1]), - _mm_sub_pd(evens[9], temp1[1]), - _mm_sub_pd(evens[10], temp2[1]), - _mm_sub_pd(evens[11], temp3[1]), - _mm_sub_pd(evens[12], temp4[1]), - _mm_sub_pd(evens[13], temp5[1]), - _mm_sub_pd(evens[14], temp6[1]), - _mm_sub_pd(evens[15], temp7[1]), - ] + let out2 = self.bf8.perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2], tmp6[2], tmp7[2]]); + store(2, out2); + + let out3 = self.bf8.perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3], tmp6[3], tmp7[3]]); + store(3, out3); } } #[cfg(test)] mod unit_tests { use super::*; - use crate::algorithm::Dft; - use crate::test_utils::{check_fft_algorithm, compare_vectors}; + use crate::{algorithm::Dft, test_utils::{check_fft_algorithm, compare_vectors}}; //the tests for all butterflies will be identical 
except for the identifiers used and size //so it's ideal for a macro @@ -3653,7 +3237,7 @@ mod unit_tests { dft.process(&mut val_a); dft.process(&mut val_b); - let res_both = bf4.perform_parallel_fft_direct(p1, p2, p3, p4); + let res_both = bf4.perform_parallel_fft_direct([p1, p2, p3, p4]); let res = std::mem::transmute::<[__m128; 4], [Complex; 8]>(res_both); let sse_res_a = [res[0], res[2], res[4], res[6]]; diff --git a/src/sse/sse_utils.rs b/src/sse/sse_utils.rs index b6473e48..be4ac5f8 100644 --- a/src/sse/sse_utils.rs +++ b/src/sse/sse_utils.rs @@ -63,6 +63,27 @@ impl Rotate90F32 { let temp = _mm_shuffle_ps(values, values, 0xB1); _mm_xor_ps(temp, self.sign_both) } + + #[inline(always)] + pub unsafe fn rotate_both_45(&self, values: __m128) -> __m128 { + let rotated = self.rotate_both(values); + let sum = _mm_add_ps(rotated, values); + _mm_mul_ps(sum, _mm_set1_ps(0.5f32.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_both_135(&self, values: __m128) -> __m128 { + let rotated = self.rotate_both(values); + let diff = _mm_sub_ps(rotated, values); + _mm_mul_ps(diff, _mm_set1_ps(0.5f32.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_both_225(&self, values: __m128) -> __m128 { + let rotated = self.rotate_both(values); + let diff = _mm_add_ps(rotated, values); + _mm_mul_ps(diff, _mm_set1_ps(-(0.5f32.sqrt()))) + } } // Pack low (1st) complex @@ -171,6 +192,27 @@ impl Rotate90F64 { let temp = _mm_shuffle_pd(values, values, 0x01); _mm_xor_pd(temp, self.sign) } + + #[inline(always)] + pub unsafe fn rotate_45(&self, values: __m128d) -> __m128d { + let rotated = self.rotate(values); + let sum = _mm_add_pd(rotated, values); + _mm_mul_pd(sum, _mm_set1_pd(0.5f64.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_135(&self, values: __m128d) -> __m128d { + let rotated = self.rotate(values); + let diff = _mm_sub_pd(rotated, values); + _mm_mul_pd(diff, _mm_set1_pd(0.5f64.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_225(&self, values: __m128d) -> __m128d { + let rotated = self.rotate(values); + let diff = _mm_add_pd(rotated, values); + _mm_mul_pd(diff, _mm_set1_pd(-(0.5f64.sqrt()))) + } } #[cfg(test)] diff --git a/src/sse/sse_vector.rs b/src/sse/sse_vector.rs index c223d016..455bac2e 100644 --- a/src/sse/sse_vector.rs +++ b/src/sse/sse_vector.rs @@ -148,6 +148,9 @@ pub trait SseVector: Copy + Debug + Send + Sync { unsafe fn store_partial_lo_complex(ptr: *mut Complex, data: Self); unsafe fn store_partial_hi_complex(ptr: *mut Complex, data: Self); + // math ops + unsafe fn neg(a: Self) -> Self; + /// Generates a chunk of twiddle factors starting at (X,Y) and incrementing X `COMPLEX_PER_VECTOR` times. /// The result will be [twiddle(x*y, len), twiddle((x+1)*y, len), twiddle((x+2)*y, len), ...] 
for as many complex numbers fit in a vector unsafe fn make_mixedradix_twiddle_chunk( @@ -207,6 +210,11 @@ impl SseVector for __m128 { _mm_storeh_pd(ptr as *mut f64, _mm_castps_pd(data)); } + #[inline(always)] + unsafe fn neg(a: Self) -> Self { + _mm_xor_ps(a, _mm_set1_ps(-0.0)) + } + #[inline(always)] unsafe fn make_mixedradix_twiddle_chunk( x: usize, @@ -308,6 +316,11 @@ impl SseVector for __m128d { unimplemented!("Impossible to do a partial store of complex f64's"); } + #[inline(always)] + unsafe fn neg(a: Self) -> Self { + _mm_xor_pd(a, _mm_set1_pd(-0.0)) + } + #[inline(always)] unsafe fn make_mixedradix_twiddle_chunk( x: usize, From 0218846704e18464a3cc9912e1df76bfe1856b98 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Sun, 18 Feb 2024 16:33:59 -0800 Subject: [PATCH 02/13] Display the first incorrect element on test failure --- src/test_utils.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index 1a3d6bf9..1bbb48bd 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -41,6 +41,15 @@ pub fn compare_vectors(vec1: &[Complex], vec2: &[Complex(vec1: &[Complex], vec2: &[Complex]) -> Option { + assert_eq!(vec1.len(), vec2.len()); + for (i, (&a, &b)) in vec1.iter().zip(vec2.iter()).enumerate() { + if (a - b).norm().to_f64().unwrap() > 0.1 { + return Some(i); + } + } + None +} #[allow(unused)] fn transppose_diagnostic(expected: &[Complex], actual: &[Complex]) { @@ -97,8 +106,10 @@ pub fn check_fft_algorithm( if !compare_vectors(&expected_output, &buffer) { panic!( - "process() failed, length = {}, direction = {}", - len, direction + "process() failed, length = {}, direction = {}, first diff = {:?}", + len, + direction, + first_diff(&expected_output, &buffer) ); } } From 8fa0a13a260cc2e4b13d70cd9db56a2fdea36090 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Sun, 18 Feb 2024 19:49:04 -0800 Subject: [PATCH 03/13] use set_ps to simplify reasoning about packing code --- src/sse/sse_butterflies.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index 3b1bb3ce..388ec474 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -12,15 +12,15 @@ use crate::{Direction, Fft, Length}; use super::sse_common::{assert_f32, assert_f64}; use super::sse_utils::*; -use super::sse_vector::{SseArray, SseArrayMut, SseVector}; +use super::sse_vector::{SseArrayMut, SseVector}; #[inline(always)] unsafe fn pack_32(a: Complex, b: Complex) -> __m128 { - [a,b].as_slice().load_complex(0) + _mm_set_ps(b.im, b.re, a.im, a.re) } #[inline(always)] unsafe fn pack_64(a: Complex) -> __m128d { - [a].as_slice().load_complex(0) + _mm_set_pd(a.im, a.re) } From aacfa758e0d53ab9ce77af638949286f4202b365 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Sun, 18 Feb 2024 20:50:52 -0800 Subject: [PATCH 04/13] Improved comments for the new SSE butterflies --- src/sse/sse_butterflies.rs | 139 +++++++++++++++++++++++++------------ 1 file changed, 94 insertions(+), 45 deletions(-) diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index 388ec474..72dadc31 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -2333,6 +2333,13 @@ impl SseF32Butterfly16 { #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs 
down the columns, apply twiddle factors, then transpose and do size-4 FFTs again + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i| [ buffer.load_complex(i), buffer.load_complex(i + 4), @@ -2340,6 +2347,7 @@ impl SseF32Butterfly16 { buffer.load_complex(i + 12), ]; + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors, and transpose let mut tmp0 = self.bf4.perform_parallel_fft_direct(load(0)); tmp0[1] = SseVector::mul_complex(tmp0[1], self.twiddles_packed[0]); tmp0[2] = SseVector::mul_complex(tmp0[2], self.twiddles_packed[1]); @@ -2354,13 +2362,14 @@ impl SseF32Butterfly16 { let [mid2, mid3] = transpose_complex_2x2_f32(tmp1[0], tmp1[1]); let [mid6, mid7] = transpose_complex_2x2_f32(tmp1[2], tmp1[3]); - // cross FFTs + //////////////////////////////////////////////////////////// let mut store = |i: usize, vectors: [__m128; 4]| { buffer.store_complex(vectors[0], i + 0); buffer.store_complex(vectors[1], i + 4); buffer.store_complex(vectors[2], i + 8); buffer.store_complex(vectors[3], i + 12); }; + // Size-4 FFTs down each pair of transposed columns, storing them as soon as we're done with them let out0 = self.bf4.perform_parallel_fft_direct([mid0, mid1, mid2, mid3]); store(0, out0); @@ -2371,12 +2380,13 @@ impl SseF32Butterfly16 { // benchmarking shows it's faster to always use the nonparallel version, but this is kep around for reference #[allow(unused)] pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - // we're going to hardcode a step of 4x4 mixed radix - // step 1: transpose (skipped since the vectors mean our data is already in the correct format) - // and - // step 2: column FFTs - // and - // step 3: twiddle factors + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i: usize| { let [a0, a1] = transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 16)); let [b0, b1] = transpose_complex_2x2_f32(buffer.load_complex(i + 4), buffer.load_complex(i + 20)); @@ -2385,6 +2395,7 @@ impl SseF32Butterfly16 { [[a0, b0, c0, d0], [a1, b1, c1, d1]] }; + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors let [in2, in3] = load(2); let mut tmp2 = self.bf4.perform_parallel_fft_direct(in2); let mut tmp3 = self.bf4.perform_parallel_fft_direct(in3); @@ -2395,6 +2406,7 @@ impl 
SseF32Butterfly16 { tmp3[2] = SseVector::mul_complex(tmp3[2], self.twiddle6); tmp3[3] = SseVector::mul_complex(tmp3[3], self.twiddle9); + // Do these last, because fewer twiddles means fewer temporaries forcing the above data to spill let [in0, in1] = load(0); let tmp0 = self.bf4.perform_parallel_fft_direct(in0); let mut tmp1 = self.bf4.perform_parallel_fft_direct(in1); @@ -2402,7 +2414,7 @@ impl SseF32Butterfly16 { tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddle2); tmp1[3] = SseVector::mul_complex(tmp1[3], self.twiddle3); - // step 4 and 5: transpose and cross FFTs + //////////////////////////////////////////////////////////// let mut store = |i, values_a: [__m128; 4], values_b: [__m128; 4]| { for n in 0..4 { let [a, b] = transpose_complex_2x2_f32(values_a[n], values_b[n]); @@ -2410,6 +2422,7 @@ impl SseF32Butterfly16 { buffer.store_complex(b, i + n*4 + 16); } }; + // Size-4 FFTs down each pair of transposed columns, storing them as soon as we're done with them let out0 = self.bf4.perform_parallel_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); let out1 = self.bf4.perform_parallel_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); store(0, out0, out1); @@ -2457,12 +2470,13 @@ impl SseF64Butterfly16 { #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - // we're going to hardcode a step of 4x4 mixed radix - // step 1: transpose (skipped since the vectors mean our data is already in the correct format) - // and - // step 2: column FFTs - // and - // step 3: twiddle factors + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i| [ buffer.load_complex(i), buffer.load_complex(i + 4), @@ -2470,6 +2484,7 @@ impl SseF64Butterfly16 { buffer.load_complex(i + 12), ]; + // For each column: load the data, apply our size-4 FFT, apply twiddle factors let mut tmp1 = self.bf4.perform_fft_direct(load(1)); tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddle1); tmp1[2] = self.bf4.rotate.rotate_45(tmp1[2]); @@ -2485,9 +2500,10 @@ impl SseF64Butterfly16 { tmp2[2] = self.bf4.rotate.rotate(tmp2[2]); tmp2[3] = self.bf4.rotate.rotate_135(tmp2[3]); + // Do the first column last, because no twiddles means fewer temporaries forcing the above data to spill let tmp0 = self.bf4.perform_fft_direct(load(0)); - // step 4 and 5: transpose and cross FFTs + //////////////////////////////////////////////////////////// let mut store = |i: usize, vectors: [__m128d; 4]| { buffer.store_complex(vectors[0], i + 0); buffer.store_complex(vectors[1], i + 4); @@ -2495,6 +2511,7 @@ impl SseF64Butterfly16 { buffer.store_complex(vectors[3], i + 12); }; + // Size-4 FFTs down each of our transposed columns, storing them as soon as we're done with them let out0 = self.bf4.perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); store(0, out0); @@ -2576,12 +2593,13 @@ impl 
SseF32Butterfly24 { #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - // we're going to hardcode a step of 8x4 mixed radix - // step 1: transpose (skipped since the vectors mean our data is already in the correct format) - // and - // step 2: column FFTs - // and - // step 3: twiddle factors + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i| [ buffer.load_complex(i), buffer.load_complex(i + 6), @@ -2589,6 +2607,7 @@ impl SseF32Butterfly24 { buffer.load_complex(i + 18), ]; + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors, transpose let mut tmp1 = self.bf4.perform_parallel_fft_direct(load(2)); tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddles_packed[3]); tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddles_packed[4]); @@ -2610,7 +2629,7 @@ impl SseF32Butterfly24 { let [mid0, mid1] = transpose_complex_2x2_f32(tmp0[0], tmp0[1]); let [mid6, mid7] = transpose_complex_2x2_f32(tmp0[2], tmp0[3]); - // step 4 and 5: transpose and cross FFTs + //////////////////////////////////////////////////////////// let mut store = |i, vectors: [__m128; 6]| { buffer.store_complex(vectors[0], i); buffer.store_complex(vectors[1], i + 4); @@ -2620,6 +2639,7 @@ impl SseF32Butterfly24 { buffer.store_complex(vectors[5], i + 20); }; + // Size-6 FFTs down each pair of transposed columns, storing them as soon as we're done with them let out0 = self.bf6.perform_parallel_fft_direct(mid0, mid1, mid2, mid3, mid4, mid5); store(0, out0); @@ -2629,6 +2649,13 @@ impl SseF32Butterfly24 { #[inline(always)] pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl SseArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i: usize| { let [a0, a1] = transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 24)); let [b0, b1] = transpose_complex_2x2_f32(buffer.load_complex(i + 6), buffer.load_complex(i + 30)); @@ -2637,6 +2664,7 @@ impl SseF32Butterfly24 { [[a0, b0, c0, d0], [a1, b1, c1, d1]] }; + // For each pair of columns: load the data, apply our size-4 FFT, apply 
twiddle factors let [in0, in1] = load(0); let tmp0 = self.bf4.perform_parallel_fft_direct(in0); let mut tmp1 = self.bf4.perform_parallel_fft_direct(in1); @@ -2664,7 +2692,7 @@ impl SseF32Butterfly24 { tmp5[2] = SseVector::mul_complex(tmp5[2], self.twiddle10); tmp5[3] = self.bf4.rotate.rotate_both_225(tmp5[3]); - // step 4 and 5: transpose and cross FFTs + //////////////////////////////////////////////////////////// let mut store = |i, vectors_a: [__m128; 6], vectors_b: [__m128; 6]| { for n in 0..6 { let [a, b] = transpose_complex_2x2_f32(vectors_a[n], vectors_b[n]); @@ -2673,6 +2701,7 @@ impl SseF32Butterfly24 { } }; + // Size-6 FFTs down each pair of transposed columns, storing them as soon as we're done with them let out0 = self.bf6.perform_parallel_fft_direct(tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]); let out1 = self.bf6.perform_parallel_fft_direct(tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]); store(0, out0, out1); @@ -2733,12 +2762,13 @@ impl SseF64Butterfly24 { #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - // we're going to hardcode a step of 8x4 mixed radix - // step 1: transpose (skipped since the vectors mean our data is already in the correct format) - // and - // step 2: column FFTs - // and - // step 3: twiddle factors + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i| [ buffer.load_complex(i), buffer.load_complex(i + 6), @@ -2746,6 +2776,7 @@ impl SseF64Butterfly24 { buffer.load_complex(i + 18), ]; + // For each column: load the data, apply our size-4 FFT, apply twiddle factors let mut tmp1 = self.bf4.perform_fft_direct(load(1)); tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddle1); tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddle2); @@ -2771,9 +2802,10 @@ impl SseF64Butterfly24 { tmp3[2] = self.bf4.rotate.rotate(tmp3[2]); tmp3[3] = self.bf4.rotate.rotate_135(tmp3[3]); + // Do the first column last, because no twiddles means fewer temporaries forcing the above data to spill let tmp0 = self.bf4.perform_fft_direct(load(0)); - // step 4 and 5: transpose and cross FFTs + //////////////////////////////////////////////////////////// let mut store = |i, vectors: [__m128d; 6]| { buffer.store_complex(vectors[0], i); buffer.store_complex(vectors[1], i + 4); @@ -2783,6 +2815,7 @@ impl SseF64Butterfly24 { buffer.store_complex(vectors[5], i + 20); }; + // Size-6 FFTs down each of our transposed columns, storing them as soon as we're done with them let out0 = self.bf6.perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]]); store(0, out0); @@ -2880,12 +2913,13 @@ impl SseF32Butterfly32 { #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - // we're going to hardcode a step of 8x4 mixed radix - // step 1: transpose (skipped since the vectors 
mean our data is already in the correct format) - // and - // step 2: column FFTs - // and - // step 3: twiddle factors + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i| [ buffer.load_complex(i), buffer.load_complex(i + 8), @@ -2893,6 +2927,7 @@ impl SseF32Butterfly32 { buffer.load_complex(i + 24), ]; + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors let mut tmp0 = self.bf8.bf4.perform_parallel_fft_direct(load(0)); tmp0[1] = SseVector::mul_complex(tmp0[1], self.twiddles_packed[0]); tmp0[2] = SseVector::mul_complex(tmp0[2], self.twiddles_packed[1]); @@ -2921,7 +2956,7 @@ impl SseF32Butterfly32 { let [mid6, mid7] = transpose_complex_2x2_f32(tmp3[0], tmp3[1]); let [mid14, mid15] = transpose_complex_2x2_f32(tmp3[2], tmp3[3]); - // step 4 and 5: transpose and cross FFTs + //////////////////////////////////////////////////////////// let mut store = |i, vectors: [__m128; 8]| { buffer.store_complex(vectors[0], i); buffer.store_complex(vectors[1], i + 4); @@ -2933,6 +2968,7 @@ impl SseF32Butterfly32 { buffer.store_complex(vectors[7], i + 28); }; + // Size-8 FFTs down each pair of transposed columns, storing them as soon as we're done with them let out0 = self.bf8.perform_parallel_fft_direct([mid0, mid1, mid2, mid3, mid4, mid5, mid6, mid7]); store(0, out0); @@ -2942,6 +2978,13 @@ impl SseF32Butterfly32 { #[inline(always)] pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl SseArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i: usize| { let [a0, a1] = transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 32)); let [b0, b1] = transpose_complex_2x2_f32(buffer.load_complex(i + 8), buffer.load_complex(i + 40)); @@ -2950,6 +2993,7 @@ impl SseF32Butterfly32 { [[a0, b0, c0, d0], [a1, b1, c1, d1]] }; + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors let [in0, in1] = load(0); let tmp0 = self.bf8.bf4.perform_parallel_fft_direct(in0); let mut tmp1 = self.bf8.bf4.perform_parallel_fft_direct(in1); @@ -2987,7 +3031,7 @@ impl SseF32Butterfly32 { 
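        // Illustrative aside (a minimal scalar sketch; `twiddle_scalar_f32` is a hypothetical
        // helper, not part of this patch): in this 8x4 mixed-radix step, the element at row r
        // of column c after the size-4 column FFTs is multiplied by twiddle(r*c, 32). Column 7,
        // as in the surrounding code, needs twiddles 7, 14 and 21 (rows 1-3), while column 4
        // needs twiddles 4, 8 and 12, which for the forward direction are exact 45/90/135-degree
        // rotations. That is why the code can use rotate_both_45 / rotate_both / rotate_both_135
        // there instead of a full complex multiply. Assuming forward-direction twiddles
        // e^(-2*pi*i*k/N), a scalar stand-in for twiddles::compute_twiddle looks like:
        //
        // fn twiddle_scalar_f32(k: u32, n: u32) -> (f32, f32) {
        //     let angle = -2.0 * std::f32::consts::PI * (k as f32) / (n as f32);
        //     (angle.cos(), angle.sin()) // (re, im)
        // }
        //
        // twiddle_scalar_f32(4, 32)  ~= ( 0.7071, -0.7071)  -> 45-degree rotation (rotate_both_45)
        // twiddle_scalar_f32(8, 32)  ~= ( 0.0,    -1.0   )  -> 90-degree rotation (rotate_both)
        // twiddle_scalar_f32(12, 32) ~= (-0.7071, -0.7071)  -> 135-degree rotation (rotate_both_135)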
tmp7[2] = SseVector::mul_complex(tmp7[2], self.twiddle14); tmp7[3] = SseVector::mul_complex(tmp7[3], self.twiddle21); - // step 4 and 5: transpose and cross FFTs + //////////////////////////////////////////////////////////// let mut store = |i, vectors_a: [__m128; 8], vectors_b: [__m128; 8]| { for n in 0..8 { let [a, b] = transpose_complex_2x2_f32(vectors_a[n], vectors_b[n]); @@ -2996,6 +3040,7 @@ impl SseF32Butterfly32 { } }; + // Size-8 FFTs down each pair of transposed columns, storing them as soon as we're done with them let out0 = self.bf8.perform_parallel_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0]]); let out1 = self.bf8.perform_parallel_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1], tmp6[1], tmp7[1]]); store(0, out0, out1); @@ -3071,12 +3116,13 @@ impl SseF64Butterfly32 { #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl SseArrayMut) { - // we're going to hardcode a step of 8x4 mixed radix - // step 1: transpose (skipped since the vectors mean our data is already in the correct format) - // and - // step 2: column FFTs - // and - // step 3: twiddle factors + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i| [ buffer.load_complex(i), buffer.load_complex(i + 8), @@ -3084,6 +3130,7 @@ impl SseF64Butterfly32 { buffer.load_complex(i + 24), ]; + // For each column: load the data, apply our size-4 FFT, apply twiddle factors let mut tmp1 = self.bf8.bf4.perform_fft_direct(load(1)); tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddle1); tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddle2); @@ -3119,9 +3166,10 @@ impl SseF64Butterfly32 { tmp4[2] = self.bf8.bf4.rotate.rotate(tmp4[2]); tmp4[3] = self.bf8.bf4.rotate.rotate_135(tmp4[3]); + // Do the first column last, because no twiddles means fewer temporaries forcing the above data to spill let tmp0 = self.bf8.bf4.perform_fft_direct(load(0)); - // step 4 and 5: transpose and cross FFTs + //////////////////////////////////////////////////////////// let mut store = |i, vectors: [__m128d; 8]| { buffer.store_complex(vectors[0], i); buffer.store_complex(vectors[1], i + 4); @@ -3133,6 +3181,7 @@ impl SseF64Butterfly32 { buffer.store_complex(vectors[7], i + 28); }; + // Size-8 FFTs down each of our transposed columns, storing them as soon as we're done with them let out0 = self.bf8.perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0]]); store(0, out0); From e6c22f3dac4af9ded1e5ad199bd465438b231e4f Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Sun, 18 Feb 2024 22:23:10 -0800 Subject: [PATCH 05/13] Optimized wasm simd butterflies to be more cache-friendly --- benches/bench_rustfft_wasm_simd.rs | 8 + src/wasm_simd/wasm_simd_butterflies.rs | 1973 ++++++++++-------------- src/wasm_simd/wasm_simd_utils.rs | 42 + 
src/wasm_simd/wasm_simd_vector.rs | 13 + 4 files changed, 899 insertions(+), 1137 deletions(-) diff --git a/benches/bench_rustfft_wasm_simd.rs b/benches/bench_rustfft_wasm_simd.rs index 02a34663..657bee18 100644 --- a/benches/bench_rustfft_wasm_simd.rs +++ b/benches/bench_rustfft_wasm_simd.rs @@ -157,6 +157,10 @@ fn wasm_simd_butterfly32_23(b: &mut Bencher) { bench_planned_multi_f32(b, 23); } #[bench] +fn wasm_simd_butterfly32_24(b: &mut Bencher) { + bench_planned_multi_f32(b, 24); +} +#[bench] fn wasm_simd_butterfly32_29(b: &mut Bencher) { bench_planned_multi_f32(b, 29); } @@ -238,6 +242,10 @@ fn wasm_simd_butterfly64_23(b: &mut Bencher) { bench_planned_multi_f64(b, 23); } #[bench] +fn wasm_simd_butterfly64_24(b: &mut Bencher) { + bench_planned_multi_f64(b, 24); +} +#[bench] fn wasm_simd_butterfly64_29(b: &mut Bencher) { bench_planned_multi_f64(b, 29); } diff --git a/src/wasm_simd/wasm_simd_butterflies.rs b/src/wasm_simd/wasm_simd_butterflies.rs index 32086b58..e756e132 100644 --- a/src/wasm_simd/wasm_simd_butterflies.rs +++ b/src/wasm_simd/wasm_simd_butterflies.rs @@ -12,14 +12,14 @@ use crate::{Direction, Fft, Length}; use super::wasm_simd_common::{assert_f32, assert_f64}; use super::wasm_simd_utils::*; -use super::wasm_simd_vector::WasmSimdArrayMut; +use super::wasm_simd_vector::{WasmSimdArrayMut, WasmVector, WasmVector32, WasmVector64}; #[inline(always)] -unsafe fn pack32(a: Complex, b: Complex) -> v128 { +unsafe fn pack_32(a: Complex, b: Complex) -> v128 { f32x4(a.re, a.im, b.re, b.im) } #[inline(always)] -unsafe fn pack64(a: Complex) -> v128 { +unsafe fn pack_64(a: Complex) -> v128 { f64x2(a.re, a.im) } @@ -700,7 +700,7 @@ impl WasmSimdF32Butterfly4 { let [value0ab, value1ab] = transpose_complex_2x2_f32(value01a, value01b); let [value2ab, value3ab] = transpose_complex_2x2_f32(value23a, value23b); - let out = self.perform_parallel_fft_direct(value0ab, value1ab, value2ab, value3ab); + let out = self.perform_parallel_fft_direct([value0ab, value1ab, value2ab, value3ab]); let [out0, out1] = transpose_complex_2x2_f32(out[0], out[1]); let [out2, out3] = transpose_complex_2x2_f32(out[2], out[3]); @@ -735,21 +735,15 @@ impl WasmSimdF32Butterfly4 { } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_direct( - &self, - values0: v128, - values1: v128, - values2: v128, - values3: v128, - ) -> [v128; 4] { + pub(crate) unsafe fn perform_parallel_fft_direct(&self, values: [v128; 4]) -> [v128; 4] { //we're going to hardcode a step of mixed radix //aka we're going to do the six step algorithm // step 1: transpose // and // step 2: column FFTs - let temp0 = parallel_fft2_interleaved_f32(values0, values2); - let mut temp1 = parallel_fft2_interleaved_f32(values1, values3); + let temp0 = parallel_fft2_interleaved_f32(values[0], values[2]); + let mut temp1 = parallel_fft2_interleaved_f32(values[1], values[3]); // step 3: apply twiddle factors (only one in this case, and it's either 0 + i or 0 - i) temp1[1] = self.rotate.rotate_both(temp1[1]); @@ -812,7 +806,7 @@ impl WasmSimdF64Butterfly4 { let value2 = buffer.load_complex_v128(2); let value3 = buffer.load_complex_v128(3); - let out = self.perform_fft_direct(value0, value1, value2, value3); + let out = self.perform_fft_direct([value0, value1, value2, value3]); buffer.store_complex_v128(out[0], 0); buffer.store_complex_v128(out[1], 1); @@ -821,21 +815,15 @@ impl WasmSimdF64Butterfly4 { } #[inline(always)] - pub(crate) unsafe fn perform_fft_direct( - &self, - value0: v128, - value1: v128, - value2: v128, - value3: v128, - ) -> [v128; 4] { + 
pub(crate) unsafe fn perform_fft_direct(&self, values: [v128; 4]) -> [v128; 4] { //we're going to hardcode a step of mixed radix //aka we're going to do the six step algorithm // step 1: transpose // and // step 2: column FFTs - let temp0 = solo_fft2_f64(value0, value2); - let mut temp1 = solo_fft2_f64(value1, value3); + let temp0 = solo_fft2_f64(values[0], values[2]); + let mut temp1 = solo_fft2_f64(values[1], values[3]); // step 3: apply twiddle factors (only one in this case, and it's either 0 + i or 0 - i) temp1[1] = self.rotate.rotate(temp1[1]); @@ -1309,7 +1297,7 @@ impl WasmSimdF64Butterfly6 { let value4 = buffer.load_complex_v128(4); let value5 = buffer.load_complex_v128(5); - let out = self.perform_fft_direct(value0, value1, value2, value3, value4, value5); + let out = self.perform_fft_direct([value0, value1, value2, value3, value4, value5]); buffer.store_complex_v128(out[0], 0); buffer.store_complex_v128(out[1], 1); @@ -1320,20 +1308,12 @@ impl WasmSimdF64Butterfly6 { } #[inline(always)] - pub(crate) unsafe fn perform_fft_direct( - &self, - value0: v128, - value1: v128, - value2: v128, - value3: v128, - value4: v128, - value5: v128, - ) -> [v128; 6] { + pub(crate) unsafe fn perform_fft_direct(&self, values: [v128; 6]) -> [v128; 6] { // Algorithm: 3x2 good-thomas // Size-3 FFTs down the columns of our reordered array - let mid0 = self.bf3.perform_fft_direct(value0, value2, value4); - let mid1 = self.bf3.perform_fft_direct(value3, value5, value1); + let mid0 = self.bf3.perform_fft_direct(values[0], values[2], values[4]); + let mid1 = self.bf3.perform_fft_direct(values[3], values[5], values[1]); // We normally would put twiddle factors right here, but since this is good-thomas algorithm, we don't need twiddle factors @@ -1458,10 +1438,10 @@ impl WasmSimdF32Butterfly8 { // step 2: column FFTs let val03 = self .bf4 - .perform_parallel_fft_direct(values[0], values[2], values[4], values[6]); + .perform_parallel_fft_direct([values[0], values[2], values[4], values[6]]); let mut val47 = self .bf4 - .perform_parallel_fft_direct(values[1], values[3], values[5], values[7]); + .perform_parallel_fft_direct([values[1], values[3], values[5], values[7]]); // step 3: apply twiddle factors let val5b = self.rotate90.rotate_both(val47[1]); @@ -1547,10 +1527,10 @@ impl WasmSimdF64Butterfly8 { // step 2: column FFTs let val03 = self .bf4 - .perform_fft_direct(values[0], values[2], values[4], values[6]); + .perform_fft_direct([values[0], values[2], values[4], values[6]]); let mut val47 = self .bf4 - .perform_fft_direct(values[1], values[3], values[5], values[7]); + .perform_fft_direct([values[1], values[3], values[5], values[7]]); // step 3: apply twiddle factors let val5b = self.rotate90.rotate(val47[1]); @@ -2096,13 +2076,13 @@ impl WasmSimdF32Butterfly12 { // Size-4 FFTs down the columns of our reordered array let mid0 = self .bf4 - .perform_parallel_fft_direct(values[0], values[3], values[6], values[9]); + .perform_parallel_fft_direct([values[0], values[3], values[6], values[9]]); let mid1 = self .bf4 - .perform_parallel_fft_direct(values[4], values[7], values[10], values[1]); + .perform_parallel_fft_direct([values[4], values[7], values[10], values[1]]); let mid2 = self .bf4 - .perform_parallel_fft_direct(values[8], values[11], values[2], values[5]); + .perform_parallel_fft_direct([values[8], values[11], values[2], values[5]]); // Since this is good-thomas algorithm, we don't need twiddle factors @@ -2182,13 +2162,13 @@ impl WasmSimdF64Butterfly12 { // Size-4 FFTs down the columns of our reordered 
array let mid0 = self .bf4 - .perform_fft_direct(values[0], values[3], values[6], values[9]); + .perform_fft_direct([values[0], values[3], values[6], values[9]]); let mid1 = self .bf4 - .perform_fft_direct(values[4], values[7], values[10], values[1]); + .perform_fft_direct([values[4], values[7], values[10], values[1]]); let mid2 = self .bf4 - .perform_fft_direct(values[8], values[11], values[2], values[5]); + .perform_fft_direct([values[8], values[11], values[2], values[5]]); // Since this is good-thomas algorithm, we don't need twiddle factors @@ -2431,208 +2411,192 @@ impl WasmSimdF64Butterfly15 { // pub struct WasmSimdF32Butterfly16 { - direction: FftDirection, bf4: WasmSimdF32Butterfly4, - bf8: WasmSimdF32Butterfly8, - rotate90: Rotate90F32, - twiddle01: v128, - twiddle23: v128, - twiddle01conj: v128, - twiddle23conj: v128, + twiddles_packed: [v128; 6], twiddle1: v128, twiddle2: v128, twiddle3: v128, - twiddle1c: v128, - twiddle2c: v128, - twiddle3c: v128, + twiddle6: v128, + twiddle9: v128, } boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly16, 16, - |this: &WasmSimdF32Butterfly16<_>| this.direction + |this: &WasmSimdF32Butterfly16<_>| this.bf4.direction ); boilerplate_fft_wasm_simd_common_butterfly!( WasmSimdF32Butterfly16, 16, - |this: &WasmSimdF32Butterfly16<_>| this.direction + |this: &WasmSimdF32Butterfly16<_>| this.bf4.direction ); impl WasmSimdF32Butterfly16 { - #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f32::(); - let bf8 = WasmSimdF32Butterfly8::new(direction); - let bf4 = WasmSimdF32Butterfly4::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F32::new(true) - } else { - Rotate90F32::new(false) - }; + let tw0: Complex = Complex { re: 1.0, im: 0.0 }; let tw1: Complex = twiddles::compute_twiddle(1, 16, direction); let tw2: Complex = twiddles::compute_twiddle(2, 16, direction); let tw3: Complex = twiddles::compute_twiddle(3, 16, direction); - let twiddle01 = f32x4(1.0, 0.0, tw1.re, tw1.im); - let twiddle23 = f32x4(tw2.re, tw2.im, tw3.re, tw3.im); - let twiddle01conj = f32x4(1.0, 0.0, tw1.re, -tw1.im); - let twiddle23conj = f32x4(tw2.re, -tw2.im, tw3.re, -tw3.im); - let twiddle1 = f32x4(tw1.re, tw1.im, tw1.re, tw1.im); - let twiddle2 = f32x4(tw2.re, tw2.im, tw2.re, tw2.im); - let twiddle3 = f32x4(tw3.re, tw3.im, tw3.re, tw3.im); - let twiddle1c = f32x4(tw1.re, -tw1.im, tw1.re, -tw1.im); - let twiddle2c = f32x4(tw2.re, -tw2.im, tw2.re, -tw2.im); - let twiddle3c = f32x4(tw3.re, -tw3.im, tw3.re, -tw3.im); - Self { - direction, - bf4, - bf8, - rotate90, - twiddle01, - twiddle23, - twiddle01conj, - twiddle23conj, - twiddle1, - twiddle2, - twiddle3, - twiddle1c, - twiddle2c, - twiddle3c, + let tw4: Complex = twiddles::compute_twiddle(4, 16, direction); + let tw6: Complex = twiddles::compute_twiddle(6, 16, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 16, direction); + + unsafe { + Self { + bf4: WasmSimdF32Butterfly4::new(direction), + twiddles_packed: [ + pack_32(tw0, tw1), + pack_32(tw0, tw2), + pack_32(tw0, tw3), + pack_32(tw2, tw3), + pack_32(tw4, tw6), + pack_32(tw6, tw9), + ], + twiddle1: pack_32(tw1, tw1), + twiddle2: pack_32(tw2, tw2), + twiddle3: pack_32(tw3, tw3), + twiddle6: pack_32(tw6, tw6), + twiddle9: pack_32(tw9, tw9), + } } } #[inline(always)] - unsafe fn perform_fft_contiguous(&self, mut buffer: impl WasmSimdArrayMut) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14 }); - - let out = self.perform_fft_direct(input_packed); - - 
write_complex_to_array_strided!(out, buffer, 2, {0,1,2,3,4,5,6,7}); + unsafe fn load_chunk(buffer: &impl WasmSimdArrayMut, i: usize) -> [v128; 4] { + [ + buffer.load_complex(i).0, + buffer.load_complex(i + 4).0, + buffer.load_complex(i + 8).0, + buffer.load_complex(i + 12).0, + ] } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_contiguous( - &self, - mut buffer: impl WasmSimdArrayMut, - ) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}); - - let values = interleave_complex_f32!(input_packed, 8, {0, 1, 2, 3 ,4 ,5 ,6 ,7}); - - let out = self.perform_parallel_fft_direct(values); - - let out_sorted = separate_interleaved_complex_f32!(out, {0, 2, 4, 6, 8, 10, 12, 14}); - - write_complex_to_array_strided!(out_sorted, buffer, 2, {0,1,2,3,4,5,6,7,8,9, 10, 11,12,13,14, 15}); + unsafe fn store_chunk(buffer: &mut impl WasmSimdArrayMut, i: usize, vectors: [v128; 4]) { + buffer.store_complex(WasmVector32(vectors[0]), i + 0); + buffer.store_complex(WasmVector32(vectors[1]), i + 4); + buffer.store_complex(WasmVector32(vectors[2]), i + 8); + buffer.store_complex(WasmVector32(vectors[3]), i + 12); } #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [v128; 8]) -> [v128; 8] { - // we're going to hardcode a step of split radix - // step 1: copy and reorder the input into the scratch - let in0002 = extract_lo_lo_f32(input[0], input[1]); - let in0406 = extract_lo_lo_f32(input[2], input[3]); - let in0810 = extract_lo_lo_f32(input[4], input[5]); - let in1214 = extract_lo_lo_f32(input[6], input[7]); - - let in0105 = extract_hi_hi_f32(input[0], input[2]); - let in0913 = extract_hi_hi_f32(input[4], input[6]); - let in1503 = extract_hi_hi_f32(input[7], input[1]); - let in0711 = extract_hi_hi_f32(input[3], input[5]); - - let in_evens = [in0002, in0406, in0810, in1214]; - - // step 2: column FFTs - let evens = self.bf8.perform_fft_direct(in_evens); - let mut odds1 = self.bf4.perform_fft_direct(in0105, in0913); - let mut odds3 = self.bf4.perform_fft_direct(in1503, in0711); - - // step 3: apply twiddle factors - odds1[0] = mul_complex_f32(odds1[0], self.twiddle01); - odds3[0] = mul_complex_f32(odds3[0], self.twiddle01conj); - - odds1[1] = mul_complex_f32(odds1[1], self.twiddle23); - odds3[1] = mul_complex_f32(odds3[1], self.twiddle23conj); + unsafe fn perform_fft_contiguous(&self, mut buffer: impl WasmSimdArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors, and transpose + let mut tmp0 = self + .bf4 + .perform_parallel_fft_direct(Self::load_chunk(&buffer, 0)); + tmp0[1] = mul_complex_f32(tmp0[1], self.twiddles_packed[0]); + tmp0[2] = mul_complex_f32(tmp0[2], self.twiddles_packed[1]); + tmp0[3] = mul_complex_f32(tmp0[3], 
self.twiddles_packed[2]); + let [mid0, mid1] = transpose_complex_2x2_f32(tmp0[0], tmp0[1]); + let [mid4, mid5] = transpose_complex_2x2_f32(tmp0[2], tmp0[3]); + + let mut tmp1 = self + .bf4 + .perform_parallel_fft_direct(Self::load_chunk(&buffer, 2)); + tmp1[1] = mul_complex_f32(tmp1[1], self.twiddles_packed[3]); + tmp1[2] = mul_complex_f32(tmp1[2], self.twiddles_packed[4]); + tmp1[3] = mul_complex_f32(tmp1[3], self.twiddles_packed[5]); + let [mid2, mid3] = transpose_complex_2x2_f32(tmp1[0], tmp1[1]); + let [mid6, mid7] = transpose_complex_2x2_f32(tmp1[2], tmp1[3]); + + // Size-4 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self + .bf4 + .perform_parallel_fft_direct([mid0, mid1, mid2, mid3]); + Self::store_chunk(&mut buffer, 0, out0); - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); + let out1 = self + .bf4 + .perform_parallel_fft_direct([mid4, mid5, mid6, mid7]); + Self::store_chunk(&mut buffer, 2, out1); + } - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); + #[inline(always)] + unsafe fn load_parallel_chunk(buffer: &impl WasmSimdArrayMut, i: usize) -> [[v128; 4]; 2] { + let [a0, a1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 0).0, buffer.load_complex(i + 16).0); + let [b0, b1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 4).0, buffer.load_complex(i + 20).0); + let [c0, c1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 8).0, buffer.load_complex(i + 24).0); + let [d0, d1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 12).0, buffer.load_complex(i + 28).0); + [[a0, b0, c0, d0], [a1, b1, c1, d1]] + } - //step 5: copy/add/subtract data back to buffer - [ - f32x4_add(evens[0], temp0[0]), - f32x4_add(evens[1], temp1[0]), - f32x4_add(evens[2], temp0[1]), - f32x4_add(evens[3], temp1[1]), - f32x4_sub(evens[0], temp0[0]), - f32x4_sub(evens[1], temp1[0]), - f32x4_sub(evens[2], temp0[1]), - f32x4_sub(evens[3], temp1[1]), - ] + #[inline(always)] + unsafe fn store_parallel_chunk( + buffer: &mut impl WasmSimdArrayMut, + i: usize, + values_a: [v128; 4], + values_b: [v128; 4], + ) { + for n in 0..4 { + let [a, b] = transpose_complex_2x2_f32(values_a[n], values_b[n]); + buffer.store_complex(WasmVector32(a), i + n * 4); + buffer.store_complex(WasmVector32(b), i + n * 4 + 16); + } } #[inline(always)] - unsafe fn perform_parallel_fft_direct(&self, input: [v128; 16]) -> [v128; 16] { - // we're going to hardcode a step of split radix - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf8.perform_parallel_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - ]); - let mut odds1 = self + pub(crate) unsafe fn perform_parallel_fft_contiguous( + &self, + mut buffer: impl WasmSimdArrayMut, + ) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // 
IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors + let [in2, in3] = Self::load_parallel_chunk(&buffer, 2); + let mut tmp2 = self.bf4.perform_parallel_fft_direct(in2); + let mut tmp3 = self.bf4.perform_parallel_fft_direct(in3); + tmp2[1] = mul_complex_f32(tmp2[1], self.twiddle2); + tmp2[2] = self.bf4.rotate.rotate_both(tmp2[2]); + tmp2[3] = mul_complex_f32(tmp2[3], self.twiddle6); + tmp3[1] = mul_complex_f32(tmp3[1], self.twiddle3); + tmp3[2] = mul_complex_f32(tmp3[2], self.twiddle6); + tmp3[3] = mul_complex_f32(tmp3[3], self.twiddle9); + + // Do these last, because fewer twiddles means fewer temporaries forcing the above data to spill + let [in0, in1] = Self::load_parallel_chunk(&buffer, 0); + let tmp0 = self.bf4.perform_parallel_fft_direct(in0); + let mut tmp1 = self.bf4.perform_parallel_fft_direct(in1); + tmp1[1] = mul_complex_f32(tmp1[1], self.twiddle1); + tmp1[2] = mul_complex_f32(tmp1[2], self.twiddle2); + tmp1[3] = mul_complex_f32(tmp1[3], self.twiddle3); + + // Size-4 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self .bf4 - .perform_parallel_fft_direct(input[1], input[5], input[9], input[13]); - let mut odds3 = self + .perform_parallel_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); + let out1 = self .bf4 - .perform_parallel_fft_direct(input[15], input[3], input[7], input[11]); - - // step 3: apply twiddle factors - odds1[1] = mul_complex_f32(odds1[1], self.twiddle1); - odds3[1] = mul_complex_f32(odds3[1], self.twiddle1c); + .perform_parallel_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); + Self::store_parallel_chunk(&mut buffer, 0, out0, out1); - odds1[2] = mul_complex_f32(odds1[2], self.twiddle2); - odds3[2] = mul_complex_f32(odds3[2], self.twiddle2c); - - odds1[3] = mul_complex_f32(odds1[3], self.twiddle3); - odds3[3] = mul_complex_f32(odds3[3], self.twiddle3c); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); - - //step 5: copy/add/subtract data back to buffer - [ - f32x4_add(evens[0], temp0[0]), - f32x4_add(evens[1], temp1[0]), - f32x4_add(evens[2], temp2[0]), - f32x4_add(evens[3], temp3[0]), - f32x4_add(evens[4], temp0[1]), - f32x4_add(evens[5], temp1[1]), - f32x4_add(evens[6], temp2[1]), - f32x4_add(evens[7], temp3[1]), - f32x4_sub(evens[0], temp0[0]), - f32x4_sub(evens[1], temp1[0]), - f32x4_sub(evens[2], temp2[0]), - f32x4_sub(evens[3], temp3[0]), - f32x4_sub(evens[4], temp0[1]), - f32x4_sub(evens[5], temp1[1]), - f32x4_sub(evens[6], temp2[1]), - f32x4_sub(evens[7], temp3[1]), - ] + let out2 = self + .bf4 + .perform_parallel_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2]]); + let out3 = self + .bf4 + .perform_parallel_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3]]); + Self::store_parallel_chunk(&mut 
buffer, 2, out2, out3); } } - // _ __ __ _ _ _ _ _ // / |/ /_ / /_ | || | | |__ (_) |_ // | | '_ \ _____ | '_ \| || |_| '_ \| | __| @@ -2641,155 +2605,105 @@ impl WasmSimdF32Butterfly16 { // pub struct WasmSimdF64Butterfly16 { - direction: FftDirection, bf4: WasmSimdF64Butterfly4, - bf8: WasmSimdF64Butterfly8, - rotate90: Rotate90F64, twiddle1: v128, - twiddle2: v128, twiddle3: v128, - twiddle1c: v128, - twiddle2c: v128, - twiddle3c: v128, + twiddle9: v128, } boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly16, 16, - |this: &WasmSimdF64Butterfly16<_>| this.direction + |this: &WasmSimdF64Butterfly16<_>| this.bf4.direction ); boilerplate_fft_wasm_simd_common_butterfly!( WasmSimdF64Butterfly16, 16, - |this: &WasmSimdF64Butterfly16<_>| this.direction + |this: &WasmSimdF64Butterfly16<_>| this.bf4.direction ); impl WasmSimdF64Butterfly16 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f64::(); - let bf8 = WasmSimdF64Butterfly8::new(direction); - let bf4 = WasmSimdF64Butterfly4::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F64::new(true) - } else { - Rotate90F64::new(false) - }; - let twiddle1 = unsafe { - v128_load( - &twiddles::compute_twiddle::(1, 16, direction) as *const _ as *const v128, - ) - }; - let twiddle2 = unsafe { - v128_load( - &twiddles::compute_twiddle::(2, 16, direction) as *const _ as *const v128, - ) - }; - let twiddle3 = unsafe { - v128_load( - &twiddles::compute_twiddle::(3, 16, direction) as *const _ as *const v128, - ) - }; - let twiddle1c = unsafe { - v128_load( - &twiddles::compute_twiddle::(1, 16, direction).conj() as *const _ - as *const v128, - ) - }; - let twiddle2c = unsafe { - v128_load( - &twiddles::compute_twiddle::(2, 16, direction).conj() as *const _ - as *const v128, - ) - }; - let twiddle3c = unsafe { - v128_load( - &twiddles::compute_twiddle::(3, 16, direction).conj() as *const _ - as *const v128, - ) - }; + let tw1: Complex = twiddles::compute_twiddle(1, 16, direction); + let tw3: Complex = twiddles::compute_twiddle(3, 16, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 16, direction); - Self { - direction, - bf4, - bf8, - rotate90, - twiddle1, - twiddle2, - twiddle3, - twiddle1c, - twiddle2c, - twiddle3c, + unsafe { + Self { + bf4: WasmSimdF64Butterfly4::new(direction), + twiddle1: pack_64(tw1), + twiddle3: pack_64(tw3), + twiddle9: pack_64(tw9), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl WasmSimdArrayMut) { - let values = - read_complex_to_array!(buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - - let out = self.perform_fft_direct(values); - - write_complex_to_array!(out, buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - } + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = 
|i| { + [ + buffer.load_complex(i).0, + buffer.load_complex(i + 4).0, + buffer.load_complex(i + 8).0, + buffer.load_complex(i + 12).0, + ] + }; - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [v128; 16]) -> [v128; 16] { - // we're going to hardcode a step of split radix + // For each column: load the data, apply our size-4 FFT, apply twiddle factors + let mut tmp1 = self.bf4.perform_fft_direct(load(1)); + tmp1[1] = mul_complex_f64(tmp1[1], self.twiddle1); + tmp1[2] = self.bf4.rotate.rotate_45(tmp1[2]); + tmp1[3] = mul_complex_f64(tmp1[3], self.twiddle3); + + let mut tmp3 = self.bf4.perform_fft_direct(load(3)); + tmp3[1] = mul_complex_f64(tmp3[1], self.twiddle3); + tmp3[2] = self.bf4.rotate.rotate_135(tmp3[2]); + tmp3[3] = mul_complex_f64(tmp3[3], self.twiddle9); + + let mut tmp2 = self.bf4.perform_fft_direct(load(2)); + tmp2[1] = self.bf4.rotate.rotate_45(tmp2[1]); + tmp2[2] = self.bf4.rotate.rotate(tmp2[2]); + tmp2[3] = self.bf4.rotate.rotate_135(tmp2[3]); + + // Do the first column last, because no twiddles means fewer temporaries forcing the above data to spill + let tmp0 = self.bf4.perform_fft_direct(load(0)); + + //////////////////////////////////////////////////////////// + let mut store = |i: usize, vectors: [v128; 4]| { + buffer.store_complex(WasmVector64(vectors[0]), i + 0); + buffer.store_complex(WasmVector64(vectors[1]), i + 4); + buffer.store_complex(WasmVector64(vectors[2]), i + 8); + buffer.store_complex(WasmVector64(vectors[3]), i + 12); + }; - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf8.perform_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - ]); - let mut odds1 = self - .bf4 - .perform_fft_direct(input[1], input[5], input[9], input[13]); - let mut odds3 = self + // Size-4 FFTs down each of our transposed columns, storing them as soon as we're done with them + let out0 = self .bf4 - .perform_fft_direct(input[15], input[3], input[7], input[11]); - - // step 3: apply twiddle factors - odds1[1] = mul_complex_f64(odds1[1], self.twiddle1); - odds3[1] = mul_complex_f64(odds3[1], self.twiddle1c); - - odds1[2] = mul_complex_f64(odds1[2], self.twiddle2); - odds3[2] = mul_complex_f64(odds3[2], self.twiddle2c); + .perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); + store(0, out0); - odds1[3] = mul_complex_f64(odds1[3], self.twiddle3); - odds3[3] = mul_complex_f64(odds3[3], self.twiddle3c); - - // step 4: cross FFTs - let mut temp0 = solo_fft2_f64(odds1[0], odds3[0]); - let mut temp1 = solo_fft2_f64(odds1[1], odds3[1]); - let mut temp2 = solo_fft2_f64(odds1[2], odds3[2]); - let mut temp3 = solo_fft2_f64(odds1[3], odds3[3]); + let out1 = self + .bf4 + .perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); + store(1, out1); - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate(temp0[1]); - temp1[1] = self.rotate90.rotate(temp1[1]); - temp2[1] = self.rotate90.rotate(temp2[1]); - temp3[1] = self.rotate90.rotate(temp3[1]); + let out2 = self + .bf4 + .perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2]]); + store(2, out2); - //step 5: copy/add/subtract data back to buffer - [ - f64x2_add(evens[0], temp0[0]), - f64x2_add(evens[1], temp1[0]), - f64x2_add(evens[2], temp2[0]), - f64x2_add(evens[3], temp3[0]), - f64x2_add(evens[4], temp0[1]), - f64x2_add(evens[5], temp1[1]), - f64x2_add(evens[6], temp2[1]), - f64x2_add(evens[7], temp3[1]), - f64x2_sub(evens[0], temp0[0]), - f64x2_sub(evens[1], 
temp1[0]), - f64x2_sub(evens[2], temp2[0]), - f64x2_sub(evens[3], temp3[0]), - f64x2_sub(evens[4], temp0[1]), - f64x2_sub(evens[5], temp1[1]), - f64x2_sub(evens[6], temp2[1]), - f64x2_sub(evens[7], temp3[1]), - ] + let out3 = self + .bf4 + .perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3]]); + store(3, out3); } } @@ -2801,257 +2715,222 @@ impl WasmSimdF64Butterfly16 { // pub struct WasmSimdF32Butterfly24 { - direction: FftDirection, + bf4: WasmSimdF32Butterfly4, bf6: WasmSimdF32Butterfly6, - bf12: WasmSimdF32Butterfly12, - rotate90: Rotate90F32, - twiddle01: v128, - twiddle23: v128, - twiddle45: v128, - twiddle01conj: v128, - twiddle23conj: v128, - twiddle45conj: v128, + twiddles_packed: [v128; 9], twiddle1: v128, twiddle2: v128, twiddle4: v128, twiddle5: v128, - twiddle1c: v128, - twiddle2c: v128, - twiddle4c: v128, - twiddle5c: v128, + twiddle8: v128, + twiddle10: v128, } boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly24, 24, - |this: &WasmSimdF32Butterfly24<_>| { this.direction } + |this: &WasmSimdF32Butterfly24<_>| { this.bf4.direction } ); boilerplate_fft_wasm_simd_common_butterfly!( WasmSimdF32Butterfly24, 24, - |this: &WasmSimdF32Butterfly24<_>| this.direction + |this: &WasmSimdF32Butterfly24<_>| this.bf4.direction ); impl WasmSimdF32Butterfly24 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f32::(); - let tw0 = Complex { re: 1.0, im: 0.0 }; - let tw1 = twiddles::compute_twiddle(1, 24, direction); - let tw2 = twiddles::compute_twiddle(2, 24, direction); - let tw3 = twiddles::compute_twiddle(3, 24, direction); - let tw4 = twiddles::compute_twiddle(4, 24, direction); - let tw5 = twiddles::compute_twiddle(5, 24, direction); + let tw0: Complex = Complex { re: 1.0, im: 0.0 }; + let tw1: Complex = twiddles::compute_twiddle(1, 24, direction); + let tw2: Complex = twiddles::compute_twiddle(2, 24, direction); + let tw3: Complex = twiddles::compute_twiddle(3, 24, direction); + let tw4: Complex = twiddles::compute_twiddle(4, 24, direction); + let tw5: Complex = twiddles::compute_twiddle(5, 24, direction); + let tw6: Complex = twiddles::compute_twiddle(6, 24, direction); + let tw8: Complex = twiddles::compute_twiddle(8, 24, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 24, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 24, direction); + let tw12: Complex = twiddles::compute_twiddle(12, 24, direction); + let tw15: Complex = twiddles::compute_twiddle(15, 24, direction); unsafe { Self { - direction, + bf4: WasmSimdF32Butterfly4::new(direction), bf6: WasmSimdF32Butterfly6::new(direction), - bf12: WasmSimdF32Butterfly12::new(direction), - rotate90: Rotate90F32::new(direction == FftDirection::Inverse), - twiddle01: pack32(tw0, tw1), - twiddle23: pack32(tw2, tw3), - twiddle45: pack32(tw4, tw5), - twiddle01conj: pack32(tw0.conj(), tw1.conj()), - twiddle23conj: pack32(tw2.conj(), tw3.conj()), - twiddle45conj: pack32(tw4.conj(), tw5.conj()), - twiddle1: pack32(tw1, tw1), - twiddle2: pack32(tw2, tw2), - twiddle4: pack32(tw4, tw4), - twiddle5: pack32(tw5, tw5), - twiddle1c: pack32(tw1.conj(), tw1.conj()), - twiddle2c: pack32(tw2.conj(), tw2.conj()), - twiddle4c: pack32(tw4.conj(), tw4.conj()), - twiddle5c: pack32(tw5.conj(), tw5.conj()), + twiddles_packed: [ + pack_32(tw0, tw1), + pack_32(tw0, tw2), + pack_32(tw0, tw3), + pack_32(tw2, tw3), + pack_32(tw4, tw6), + pack_32(tw6, tw9), + pack_32(tw4, tw5), + pack_32(tw8, tw10), + pack_32(tw12, tw15), + ], + twiddle1: pack_32(tw1, tw1), + twiddle2: pack_32(tw2, tw2), + twiddle4: 
pack_32(tw4, tw4), + twiddle5: pack_32(tw5, tw5), + twiddle8: pack_32(tw8, tw8), + twiddle10: pack_32(tw10, tw10), } } } #[inline(always)] - unsafe fn perform_fft_contiguous(&self, mut buffer: impl WasmSimdArrayMut) { - let input_packed = - read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22}); - - let out = self.perform_fft_direct(input_packed); - - write_complex_to_array_strided!(out, buffer, 2, {0,1,2,3,4,5,6,7,8,9,10,11}); + unsafe fn load_chunk(buffer: &impl WasmSimdArrayMut, i: usize) -> [v128; 4] { + [ + buffer.load_complex(i).0, + buffer.load_complex(i + 6).0, + buffer.load_complex(i + 12).0, + buffer.load_complex(i + 18).0, + ] } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_contiguous( - &self, - mut buffer: impl WasmSimdArrayMut, - ) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46}); - - let values = - interleave_complex_f32!(input_packed, 12, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); - - let out = self.perform_parallel_fft_direct(values); - - let out_sorted = - separate_interleaved_complex_f32!(out, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22}); - - write_complex_to_array_strided!(out_sorted, buffer, 2, {0,1,2,3,4,5,6,7,8,9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 }); + unsafe fn store_chunk(buffer: &mut impl WasmSimdArrayMut, i: usize, vectors: [v128; 6]) { + buffer.store_complex(WasmVector32(vectors[0]), i + 0); + buffer.store_complex(WasmVector32(vectors[1]), i + 4); + buffer.store_complex(WasmVector32(vectors[2]), i + 8); + buffer.store_complex(WasmVector32(vectors[3]), i + 12); + buffer.store_complex(WasmVector32(vectors[4]), i + 16); + buffer.store_complex(WasmVector32(vectors[5]), i + 20); } #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [v128; 12]) -> [v128; 12] { - // we're going to hardcode a step of split radix - - // step 1: copy and reorder the input into the scratch - let in0002 = extract_lo_lo_f32(input[0], input[1]); - let in0406 = extract_lo_lo_f32(input[2], input[3]); - let in0810 = extract_lo_lo_f32(input[4], input[5]); - let in1214 = extract_lo_lo_f32(input[6], input[7]); - let in1618 = extract_lo_lo_f32(input[8], input[9]); - let in2022 = extract_lo_lo_f32(input[10], input[11]); - - let in0105 = extract_hi_hi_f32(input[0], input[2]); - let in0913 = extract_hi_hi_f32(input[4], input[6]); - let in1721 = extract_hi_hi_f32(input[8], input[10]); - - let in2303 = extract_hi_hi_f32(input[11], input[1]); - let in0711 = extract_hi_hi_f32(input[3], input[5]); - let in1519 = extract_hi_hi_f32(input[7], input[9]); - - let in_evens = [in0002, in0406, in0810, in1214, in1618, in2022]; - - // step 2: column FFTs - let evens = self.bf12.perform_fft_direct(in_evens); - let mut odds1 = self.bf6.perform_fft_direct(in0105, in0913, in1721); - let mut odds3 = self.bf6.perform_fft_direct(in2303, in0711, in1519); - - // step 3: apply twiddle factors - odds1[0] = mul_complex_f32(odds1[0], self.twiddle01); - odds3[0] = mul_complex_f32(odds3[0], self.twiddle01conj); - - odds1[1] = mul_complex_f32(odds1[1], self.twiddle23); - odds3[1] = mul_complex_f32(odds3[1], self.twiddle23conj); - - odds1[2] = mul_complex_f32(odds1[2], self.twiddle45); - odds3[2] = mul_complex_f32(odds3[2], self.twiddle45conj); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], 
odds3[2]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - - //step 5: copy/add/subtract data back to buffer - [ - f32x4_add(evens[0], temp0[0]), - f32x4_add(evens[1], temp1[0]), - f32x4_add(evens[2], temp2[0]), - f32x4_add(evens[3], temp0[1]), - f32x4_add(evens[4], temp1[1]), - f32x4_add(evens[5], temp2[1]), - f32x4_sub(evens[0], temp0[0]), - f32x4_sub(evens[1], temp1[0]), - f32x4_sub(evens[2], temp2[0]), - f32x4_sub(evens[3], temp0[1]), - f32x4_sub(evens[4], temp1[1]), - f32x4_sub(evens[5], temp2[1]), - ] + unsafe fn perform_fft_contiguous(&self, mut buffer: impl WasmSimdArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors, transpose + let mut tmp1 = self + .bf4 + .perform_parallel_fft_direct(Self::load_chunk(&buffer, 2)); + tmp1[1] = mul_complex_f32(tmp1[1], self.twiddles_packed[3]); + tmp1[2] = mul_complex_f32(tmp1[2], self.twiddles_packed[4]); + tmp1[3] = mul_complex_f32(tmp1[3], self.twiddles_packed[5]); + let [mid2, mid3] = transpose_complex_2x2_f32(tmp1[0], tmp1[1]); + let [mid8, mid9] = transpose_complex_2x2_f32(tmp1[2], tmp1[3]); + + let mut tmp2 = self + .bf4 + .perform_parallel_fft_direct(Self::load_chunk(&buffer, 4)); + tmp2[1] = mul_complex_f32(tmp2[1], self.twiddles_packed[6]); + tmp2[2] = mul_complex_f32(tmp2[2], self.twiddles_packed[7]); + tmp2[3] = mul_complex_f32(tmp2[3], self.twiddles_packed[8]); + let [mid4, mid5] = transpose_complex_2x2_f32(tmp2[0], tmp2[1]); + let [mid10, mid11] = transpose_complex_2x2_f32(tmp2[2], tmp2[3]); + + let mut tmp0 = self + .bf4 + .perform_parallel_fft_direct(Self::load_chunk(&buffer, 0)); + tmp0[1] = mul_complex_f32(tmp0[1], self.twiddles_packed[0]); + tmp0[2] = mul_complex_f32(tmp0[2], self.twiddles_packed[1]); + tmp0[3] = mul_complex_f32(tmp0[3], self.twiddles_packed[2]); + let [mid0, mid1] = transpose_complex_2x2_f32(tmp0[0], tmp0[1]); + let [mid6, mid7] = transpose_complex_2x2_f32(tmp0[2], tmp0[3]); + + // Size-6 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self + .bf6 + .perform_parallel_fft_direct(mid0, mid1, mid2, mid3, mid4, mid5); + Self::store_chunk(&mut buffer, 0, out0); + + let out1 = self + .bf6 + .perform_parallel_fft_direct(mid6, mid7, mid8, mid9, mid10, mid11); + Self::store_chunk(&mut buffer, 2, out1); + } + + #[inline(always)] + unsafe fn load_parallel_chunk(buffer: &impl WasmSimdArrayMut, i: usize) -> [[v128; 4]; 2] { + let [a0, a1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 0).0, buffer.load_complex(i + 24).0); + let [b0, b1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 
6).0, buffer.load_complex(i + 30).0); + let [c0, c1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 12).0, buffer.load_complex(i + 36).0); + let [d0, d1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 18).0, buffer.load_complex(i + 42).0); + [[a0, b0, c0, d0], [a1, b1, c1, d1]] + } + + #[inline(always)] + unsafe fn store_parallel_chunk( + buffer: &mut impl WasmSimdArrayMut, + i: usize, + values_a: [v128; 6], + values_b: [v128; 6], + ) { + for n in 0..6 { + let [a, b] = transpose_complex_2x2_f32(values_a[n], values_b[n]); + buffer.store_complex(WasmVector32(a), i + n * 4); + buffer.store_complex(WasmVector32(b), i + n * 4 + 24); + } } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_direct(&self, input: [v128; 24]) -> [v128; 24] { - // we're going to hardcode a step of split radix - - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf12.perform_parallel_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], - ]); - let mut odds1 = self.bf6.perform_parallel_fft_direct( - input[1], input[5], input[9], input[13], input[17], input[21], - ); - let mut odds3 = self.bf6.perform_parallel_fft_direct( - input[23], input[3], input[7], input[11], input[15], input[19], - ); - - // twiddle factor helpers - let rotate45 = |vec| { - let rotated = self.rotate90.rotate_both(vec); - let sum = f32x4_add(vec, rotated); - f32x4_mul( - sum, - v128_load32_splat(&0.5f32.sqrt() as *const f32 as *const u32), - ) - }; - let rotate315 = |vec| { - let rotated = self.rotate90.rotate_both(vec); - let sum = f32x4_sub(vec, rotated); - f32x4_mul( - sum, - v128_load32_splat(&0.5f32.sqrt() as *const f32 as *const u32), - ) - }; - - // step 3: apply twiddle factors - odds1[1] = mul_complex_f32(odds1[1], self.twiddle1); - odds3[1] = mul_complex_f32(odds3[1], self.twiddle1c); - - odds1[2] = mul_complex_f32(odds1[2], self.twiddle2); - odds3[2] = mul_complex_f32(odds3[2], self.twiddle2c); - - odds1[3] = rotate45(odds1[3]); - odds3[3] = rotate315(odds3[3]); - - odds1[4] = mul_complex_f32(odds1[4], self.twiddle4); - odds3[4] = mul_complex_f32(odds3[4], self.twiddle4c); - - odds1[5] = mul_complex_f32(odds1[5], self.twiddle5); - odds3[5] = mul_complex_f32(odds3[5], self.twiddle5c); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); - let mut temp4 = parallel_fft2_interleaved_f32(odds1[4], odds3[4]); - let mut temp5 = parallel_fft2_interleaved_f32(odds1[5], odds3[5]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); - temp4[1] = self.rotate90.rotate_both(temp4[1]); - temp5[1] = self.rotate90.rotate_both(temp5[1]); - - //step 5: copy/add/subtract data back to buffer - [ - f32x4_add(evens[0], temp0[0]), - f32x4_add(evens[1], temp1[0]), - f32x4_add(evens[2], temp2[0]), - f32x4_add(evens[3], temp3[0]), - f32x4_add(evens[4], temp4[0]), - f32x4_add(evens[5], temp5[0]), - f32x4_add(evens[6], temp0[1]), - f32x4_add(evens[7], temp1[1]), - f32x4_add(evens[8], temp2[1]), - f32x4_add(evens[9], temp3[1]), - 
f32x4_add(evens[10], temp4[1]), - f32x4_add(evens[11], temp5[1]), - f32x4_sub(evens[0], temp0[0]), - f32x4_sub(evens[1], temp1[0]), - f32x4_sub(evens[2], temp2[0]), - f32x4_sub(evens[3], temp3[0]), - f32x4_sub(evens[4], temp4[0]), - f32x4_sub(evens[5], temp5[0]), - f32x4_sub(evens[6], temp0[1]), - f32x4_sub(evens[7], temp1[1]), - f32x4_sub(evens[8], temp2[1]), - f32x4_sub(evens[9], temp3[1]), - f32x4_sub(evens[10], temp4[1]), - f32x4_sub(evens[11], temp5[1]), - ] + pub(crate) unsafe fn perform_parallel_fft_contiguous( + &self, + mut buffer: impl WasmSimdArrayMut, + ) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors + let [in0, in1] = Self::load_parallel_chunk(&buffer, 0); + let tmp0 = self.bf4.perform_parallel_fft_direct(in0); + let mut tmp1 = self.bf4.perform_parallel_fft_direct(in1); + tmp1[1] = mul_complex_f32(tmp1[1], self.twiddle1); + tmp1[2] = mul_complex_f32(tmp1[2], self.twiddle2); + tmp1[3] = self.bf4.rotate.rotate_both_45(tmp1[3]); + + let [in2, in3] = Self::load_parallel_chunk(&buffer, 2); + let mut tmp2 = self.bf4.perform_parallel_fft_direct(in2); + let mut tmp3 = self.bf4.perform_parallel_fft_direct(in3); + tmp2[1] = mul_complex_f32(tmp2[1], self.twiddle2); + tmp2[2] = mul_complex_f32(tmp2[2], self.twiddle4); + tmp2[3] = self.bf4.rotate.rotate_both(tmp2[3]); + tmp3[1] = self.bf4.rotate.rotate_both_45(tmp3[1]); + tmp3[2] = self.bf4.rotate.rotate_both(tmp3[2]); + tmp3[3] = self.bf4.rotate.rotate_both_135(tmp3[3]); + + let [in4, in5] = Self::load_parallel_chunk(&buffer, 4); + let mut tmp4 = self.bf4.perform_parallel_fft_direct(in4); + let mut tmp5 = self.bf4.perform_parallel_fft_direct(in5); + tmp4[1] = mul_complex_f32(tmp4[1], self.twiddle4); + tmp4[2] = mul_complex_f32(tmp4[2], self.twiddle8); + tmp4[3] = WasmVector::neg(WasmVector32(tmp4[3])).0; + tmp5[1] = mul_complex_f32(tmp5[1], self.twiddle5); + tmp5[2] = mul_complex_f32(tmp5[2], self.twiddle10); + tmp5[3] = self.bf4.rotate.rotate_both_225(tmp5[3]); + + // Size-6 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self + .bf6 + .perform_parallel_fft_direct(tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]); + let out1 = self + .bf6 + .perform_parallel_fft_direct(tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]); + Self::store_parallel_chunk(&mut buffer, 0, out0, out1); + + let out2 = self + .bf6 + .perform_parallel_fft_direct(tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2]); + let out3 = self + .bf6 + .perform_parallel_fft_direct(tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3]); + Self::store_parallel_chunk(&mut buffer, 2, out2, out3); } } @@ -3063,160 +2942,128 @@ impl WasmSimdF32Butterfly24 { // pub struct WasmSimdF64Butterfly24 { - direction: FftDirection, + 
bf4: WasmSimdF64Butterfly4, bf6: WasmSimdF64Butterfly6, - bf12: WasmSimdF64Butterfly12, - rotate90: Rotate90F64, twiddle1: v128, twiddle2: v128, twiddle4: v128, twiddle5: v128, - twiddle1c: v128, - twiddle2c: v128, - twiddle4c: v128, - twiddle5c: v128, + twiddle8: v128, + twiddle10: v128, } boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly24, 24, - |this: &WasmSimdF64Butterfly24<_>| { this.direction } + |this: &WasmSimdF64Butterfly24<_>| { this.bf4.direction } ); boilerplate_fft_wasm_simd_common_butterfly!( WasmSimdF64Butterfly24, 24, - |this: &WasmSimdF64Butterfly24<_>| this.direction + |this: &WasmSimdF64Butterfly24<_>| this.bf4.direction ); impl WasmSimdF64Butterfly24 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f64::(); - let twiddle1 = twiddles::compute_twiddle(1, 24, direction); - let twiddle2 = twiddles::compute_twiddle(2, 24, direction); - let twiddle4 = twiddles::compute_twiddle(4, 24, direction); - let twiddle5 = twiddles::compute_twiddle(5, 24, direction); + let tw1: Complex = twiddles::compute_twiddle(1, 24, direction); + let tw2: Complex = twiddles::compute_twiddle(2, 24, direction); + let tw4: Complex = twiddles::compute_twiddle(4, 24, direction); + let tw5: Complex = twiddles::compute_twiddle(5, 24, direction); + let tw8: Complex = twiddles::compute_twiddle(8, 24, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 24, direction); + unsafe { Self { - direction, + bf4: WasmSimdF64Butterfly4::new(direction), bf6: WasmSimdF64Butterfly6::new(direction), - bf12: WasmSimdF64Butterfly12::new(direction), - rotate90: Rotate90F64::new(direction == FftDirection::Inverse), - twiddle1: pack64(twiddle1), - twiddle2: pack64(twiddle2), - twiddle4: pack64(twiddle4), - twiddle5: pack64(twiddle5), - twiddle1c: pack64(twiddle1.conj()), - twiddle2c: pack64(twiddle2.conj()), - twiddle4c: pack64(twiddle4.conj()), - twiddle5c: pack64(twiddle5.conj()), + twiddle1: pack_64(tw1), + twiddle2: pack_64(tw2), + twiddle4: pack_64(tw4), + twiddle5: pack_64(tw5), + twiddle8: pack_64(tw8), + twiddle10: pack_64(tw10), } } } #[inline(always)] - pub(crate) unsafe fn perform_fft_contiguous(&self, mut buffer: impl WasmSimdArrayMut) { - let values = read_complex_to_array!(buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); - - let out = self.perform_fft_direct(values); + unsafe fn perform_fft_contiguous(&self, mut buffer: impl WasmSimdArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i| { + [ + buffer.load_complex(i).0, + buffer.load_complex(i + 6).0, + buffer.load_complex(i + 12).0, + buffer.load_complex(i + 18).0, + ] + }; - write_complex_to_array!(out, buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); - } + // For each column: load the data, 
apply our size-4 FFT, apply twiddle factors + let mut tmp1 = self.bf4.perform_fft_direct(load(1)); + tmp1[1] = mul_complex_f64(tmp1[1], self.twiddle1); + tmp1[2] = mul_complex_f64(tmp1[2], self.twiddle2); + tmp1[3] = self.bf4.rotate.rotate_45(tmp1[3]); + + let mut tmp2 = self.bf4.perform_fft_direct(load(2)); + tmp2[1] = mul_complex_f64(tmp2[1], self.twiddle2); + tmp2[2] = mul_complex_f64(tmp2[2], self.twiddle4); + tmp2[3] = self.bf4.rotate.rotate(tmp2[3]); + + let mut tmp4 = self.bf4.perform_fft_direct(load(4)); + tmp4[1] = mul_complex_f64(tmp4[1], self.twiddle4); + tmp4[2] = mul_complex_f64(tmp4[2], self.twiddle8); + tmp4[3] = WasmVector::neg(WasmVector64(tmp4[3])).0; + + let mut tmp5 = self.bf4.perform_fft_direct(load(5)); + tmp5[1] = mul_complex_f64(tmp5[1], self.twiddle5); + tmp5[2] = mul_complex_f64(tmp5[2], self.twiddle10); + tmp5[3] = self.bf4.rotate.rotate_225(tmp5[3]); + + let mut tmp3 = self.bf4.perform_fft_direct(load(3)); + tmp3[1] = self.bf4.rotate.rotate_45(tmp3[1]); + tmp3[2] = self.bf4.rotate.rotate(tmp3[2]); + tmp3[3] = self.bf4.rotate.rotate_135(tmp3[3]); + + // Do the first column last, because no twiddles means fewer temporaries forcing the above data to spill + let tmp0 = self.bf4.perform_fft_direct(load(0)); + + //////////////////////////////////////////////////////////// + let mut store = |i, vectors: [v128; 6]| { + buffer.store_complex(WasmVector64(vectors[0]), i); + buffer.store_complex(WasmVector64(vectors[1]), i + 4); + buffer.store_complex(WasmVector64(vectors[2]), i + 8); + buffer.store_complex(WasmVector64(vectors[3]), i + 12); + buffer.store_complex(WasmVector64(vectors[4]), i + 16); + buffer.store_complex(WasmVector64(vectors[5]), i + 20); + }; - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [v128; 24]) -> [v128; 24] { - // we're going to hardcode a step of split radix + // Size-6 FFTs down each of our transposed columns, storing them as soon as we're done with them + let out0 = self + .bf6 + .perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]]); + store(0, out0); - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf12.perform_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], - ]); - let mut odds1 = self.bf6.perform_fft_direct( - input[1], input[5], input[9], input[13], input[17], input[21], - ); - let mut odds3 = self.bf6.perform_fft_direct( - input[23], input[3], input[7], input[11], input[15], input[19], - ); + let out1 = self + .bf6 + .perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]]); + store(1, out1); - // twiddle factor helpers - let rotate45 = |vec| { - let rotated = self.rotate90.rotate(vec); - let sum = f64x2_add(vec, rotated); - f64x2_mul( - sum, - v128_load64_splat(&0.5f64.sqrt() as *const f64 as *const u64), - ) - }; - let rotate315 = |vec| { - let rotated = self.rotate90.rotate(vec); - let sum = f64x2_sub(vec, rotated); - f64x2_mul( - sum, - v128_load64_splat(&0.5f64.sqrt() as *const f64 as *const u64), - ) - }; + let out2 = self + .bf6 + .perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2]]); + store(2, out2); - // step 3: apply twiddle factors - odds1[1] = mul_complex_f64(odds1[1], self.twiddle1); - odds3[1] = mul_complex_f64(odds3[1], self.twiddle1c); - - odds1[2] = mul_complex_f64(odds1[2], self.twiddle2); - odds3[2] = mul_complex_f64(odds3[2], self.twiddle2c); - - odds1[3] = rotate45(odds1[3]); - odds3[3] 
= rotate315(odds3[3]); - - odds1[4] = mul_complex_f64(odds1[4], self.twiddle4); - odds3[4] = mul_complex_f64(odds3[4], self.twiddle4c); - - odds1[5] = mul_complex_f64(odds1[5], self.twiddle5); - odds3[5] = mul_complex_f64(odds3[5], self.twiddle5c); - - // step 4: cross FFTs - let mut temp0 = solo_fft2_f64(odds1[0], odds3[0]); - let mut temp1 = solo_fft2_f64(odds1[1], odds3[1]); - let mut temp2 = solo_fft2_f64(odds1[2], odds3[2]); - let mut temp3 = solo_fft2_f64(odds1[3], odds3[3]); - let mut temp4 = solo_fft2_f64(odds1[4], odds3[4]); - let mut temp5 = solo_fft2_f64(odds1[5], odds3[5]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate(temp0[1]); - temp1[1] = self.rotate90.rotate(temp1[1]); - temp2[1] = self.rotate90.rotate(temp2[1]); - temp3[1] = self.rotate90.rotate(temp3[1]); - temp4[1] = self.rotate90.rotate(temp4[1]); - temp5[1] = self.rotate90.rotate(temp5[1]); - - //step 5: copy/add/subtract data back to buffer - [ - f64x2_add(evens[0], temp0[0]), - f64x2_add(evens[1], temp1[0]), - f64x2_add(evens[2], temp2[0]), - f64x2_add(evens[3], temp3[0]), - f64x2_add(evens[4], temp4[0]), - f64x2_add(evens[5], temp5[0]), - f64x2_add(evens[6], temp0[1]), - f64x2_add(evens[7], temp1[1]), - f64x2_add(evens[8], temp2[1]), - f64x2_add(evens[9], temp3[1]), - f64x2_add(evens[10], temp4[1]), - f64x2_add(evens[11], temp5[1]), - f64x2_sub(evens[0], temp0[0]), - f64x2_sub(evens[1], temp1[0]), - f64x2_sub(evens[2], temp2[0]), - f64x2_sub(evens[3], temp3[0]), - f64x2_sub(evens[4], temp4[0]), - f64x2_sub(evens[5], temp5[0]), - f64x2_sub(evens[6], temp0[1]), - f64x2_sub(evens[7], temp1[1]), - f64x2_sub(evens[8], temp2[1]), - f64x2_sub(evens[9], temp3[1]), - f64x2_sub(evens[10], temp4[1]), - f64x2_sub(evens[11], temp5[1]), - ] + let out3 = self + .bf6 + .perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3]]); + store(3, out3); } } @@ -3228,55 +3075,37 @@ impl WasmSimdF64Butterfly24 { // pub struct WasmSimdF32Butterfly32 { - direction: FftDirection, bf8: WasmSimdF32Butterfly8, - bf16: WasmSimdF32Butterfly16, - rotate90: Rotate90F32, - twiddle01: v128, - twiddle23: v128, - twiddle45: v128, - twiddle67: v128, - twiddle01conj: v128, - twiddle23conj: v128, - twiddle45conj: v128, - twiddle67conj: v128, + twiddles_packed: [v128; 12], twiddle1: v128, twiddle2: v128, twiddle3: v128, - twiddle4: v128, twiddle5: v128, twiddle6: v128, twiddle7: v128, - twiddle1c: v128, - twiddle2c: v128, - twiddle3c: v128, - twiddle4c: v128, - twiddle5c: v128, - twiddle6c: v128, - twiddle7c: v128, + twiddle9: v128, + twiddle10: v128, + twiddle14: v128, + twiddle15: v128, + twiddle18: v128, + twiddle21: v128, } boilerplate_fft_wasm_simd_f32_butterfly!( WasmSimdF32Butterfly32, 32, - |this: &WasmSimdF32Butterfly32<_>| this.direction + |this: &WasmSimdF32Butterfly32<_>| this.bf8.bf4.direction ); boilerplate_fft_wasm_simd_common_butterfly!( WasmSimdF32Butterfly32, 32, - |this: &WasmSimdF32Butterfly32<_>| this.direction + |this: &WasmSimdF32Butterfly32<_>| this.bf8.bf4.direction ); impl WasmSimdF32Butterfly32 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f32::(); - let bf8 = WasmSimdF32Butterfly8::new(direction); - let bf16 = WasmSimdF32Butterfly16::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F32::new(true) - } else { - Rotate90F32::new(false) - }; + let tw0: Complex = Complex { re: 1.0, im: 0.0 }; let tw1: Complex = twiddles::compute_twiddle(1, 32, direction); let tw2: Complex = twiddles::compute_twiddle(2, 
32, direction); let tw3: Complex = twiddles::compute_twiddle(3, 32, direction); @@ -3284,261 +3113,226 @@ impl WasmSimdF32Butterfly32 { let tw5: Complex = twiddles::compute_twiddle(5, 32, direction); let tw6: Complex = twiddles::compute_twiddle(6, 32, direction); let tw7: Complex = twiddles::compute_twiddle(7, 32, direction); - let twiddle01 = f32x4(1.0, 0.0, tw1.re, tw1.im); - let twiddle23 = f32x4(tw2.re, tw2.im, tw3.re, tw3.im); - let twiddle45 = f32x4(tw4.re, tw4.im, tw5.re, tw5.im); - let twiddle67 = f32x4(tw6.re, tw6.im, tw7.re, tw7.im); - let twiddle01conj = f32x4(1.0, 0.0, tw1.re, -tw1.im); - let twiddle23conj = f32x4(tw2.re, -tw2.im, tw3.re, -tw3.im); - let twiddle45conj = f32x4(tw4.re, -tw4.im, tw5.re, -tw5.im); - let twiddle67conj = f32x4(tw6.re, -tw6.im, tw7.re, -tw7.im); - let twiddle1 = f32x4(tw1.re, tw1.im, tw1.re, tw1.im); - let twiddle2 = f32x4(tw2.re, tw2.im, tw2.re, tw2.im); - let twiddle3 = f32x4(tw3.re, tw3.im, tw3.re, tw3.im); - let twiddle4 = f32x4(tw4.re, tw4.im, tw4.re, tw4.im); - let twiddle5 = f32x4(tw5.re, tw5.im, tw5.re, tw5.im); - let twiddle6 = f32x4(tw6.re, tw6.im, tw6.re, tw6.im); - let twiddle7 = f32x4(tw7.re, tw7.im, tw7.re, tw7.im); - let twiddle1c = f32x4(tw1.re, -tw1.im, tw1.re, -tw1.im); - let twiddle2c = f32x4(tw2.re, -tw2.im, tw2.re, -tw2.im); - let twiddle3c = f32x4(tw3.re, -tw3.im, tw3.re, -tw3.im); - let twiddle4c = f32x4(tw4.re, -tw4.im, tw4.re, -tw4.im); - let twiddle5c = f32x4(tw5.re, -tw5.im, tw5.re, -tw5.im); - let twiddle6c = f32x4(tw6.re, -tw6.im, tw6.re, -tw6.im); - let twiddle7c = f32x4(tw7.re, -tw7.im, tw7.re, -tw7.im); - Self { - direction, - bf8, - bf16, - rotate90, - twiddle01, - twiddle23, - twiddle45, - twiddle67, - twiddle01conj, - twiddle23conj, - twiddle45conj, - twiddle67conj, - twiddle1, - twiddle2, - twiddle3, - twiddle4, - twiddle5, - twiddle6, - twiddle7, - twiddle1c, - twiddle2c, - twiddle3c, - twiddle4c, - twiddle5c, - twiddle6c, - twiddle7c, + let tw8: Complex = twiddles::compute_twiddle(8, 32, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 32, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 32, direction); + let tw12: Complex = twiddles::compute_twiddle(12, 32, direction); + let tw14: Complex = twiddles::compute_twiddle(14, 32, direction); + let tw15: Complex = twiddles::compute_twiddle(15, 32, direction); + let tw18: Complex = twiddles::compute_twiddle(18, 32, direction); + let tw21: Complex = twiddles::compute_twiddle(21, 32, direction); + unsafe { + Self { + bf8: WasmSimdF32Butterfly8::new(direction), + twiddles_packed: [ + pack_32(tw0, tw1), + pack_32(tw0, tw2), + pack_32(tw0, tw3), + pack_32(tw2, tw3), + pack_32(tw4, tw6), + pack_32(tw6, tw9), + pack_32(tw4, tw5), + pack_32(tw8, tw10), + pack_32(tw12, tw15), + pack_32(tw6, tw7), + pack_32(tw12, tw14), + pack_32(tw18, tw21), + ], + twiddle1: pack_32(tw1, tw1), + twiddle2: pack_32(tw2, tw2), + twiddle3: pack_32(tw3, tw3), + twiddle5: pack_32(tw5, tw5), + twiddle6: pack_32(tw6, tw6), + twiddle7: pack_32(tw7, tw7), + twiddle9: pack_32(tw9, tw9), + twiddle10: pack_32(tw10, tw10), + twiddle14: pack_32(tw14, tw14), + twiddle15: pack_32(tw15, tw15), + twiddle18: pack_32(tw18, tw18), + twiddle21: pack_32(tw21, tw21), + } } } #[inline(always)] - unsafe fn perform_fft_contiguous(&self, mut buffer: impl WasmSimdArrayMut) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 }); - - let out = self.perform_fft_direct(input_packed); - - write_complex_to_array_strided!(out, buffer, 2, 
{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); + unsafe fn load_chunk(buffer: &impl WasmSimdArrayMut, i: usize) -> [v128; 4] { + [ + buffer.load_complex(i).0, + buffer.load_complex(i + 8).0, + buffer.load_complex(i + 16).0, + buffer.load_complex(i + 24).0, + ] } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_contiguous( - &self, - mut buffer: impl WasmSimdArrayMut, - ) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}); - - let values = interleave_complex_f32!(input_packed, 16, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - - let out = self.perform_parallel_fft_direct(values); - - let out_sorted = separate_interleaved_complex_f32!(out, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}); - - write_complex_to_array_strided!(out_sorted, buffer, 2, {0,1,2,3,4,5,6,7,8,9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 }); + unsafe fn store_chunk(buffer: &mut impl WasmSimdArrayMut, i: usize, vectors: [v128; 8]) { + buffer.store_complex(WasmVector32(vectors[0]), i + 0); + buffer.store_complex(WasmVector32(vectors[1]), i + 4); + buffer.store_complex(WasmVector32(vectors[2]), i + 8); + buffer.store_complex(WasmVector32(vectors[3]), i + 12); + buffer.store_complex(WasmVector32(vectors[4]), i + 16); + buffer.store_complex(WasmVector32(vectors[5]), i + 20); + buffer.store_complex(WasmVector32(vectors[6]), i + 24); + buffer.store_complex(WasmVector32(vectors[7]), i + 28); } #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [v128; 16]) -> [v128; 16] { - // we're going to hardcode a step of split radix - - // step 1: copy and reorder the input into the scratch - let in0002 = extract_lo_lo_f32(input[0], input[1]); - let in0406 = extract_lo_lo_f32(input[2], input[3]); - let in0810 = extract_lo_lo_f32(input[4], input[5]); - let in1214 = extract_lo_lo_f32(input[6], input[7]); - let in1618 = extract_lo_lo_f32(input[8], input[9]); - let in2022 = extract_lo_lo_f32(input[10], input[11]); - let in2426 = extract_lo_lo_f32(input[12], input[13]); - let in2830 = extract_lo_lo_f32(input[14], input[15]); - - let in0105 = extract_hi_hi_f32(input[0], input[2]); - let in0913 = extract_hi_hi_f32(input[4], input[6]); - let in1721 = extract_hi_hi_f32(input[8], input[10]); - let in2529 = extract_hi_hi_f32(input[12], input[14]); - - let in3103 = extract_hi_hi_f32(input[15], input[1]); - let in0711 = extract_hi_hi_f32(input[3], input[5]); - let in1519 = extract_hi_hi_f32(input[7], input[9]); - let in2327 = extract_hi_hi_f32(input[11], input[13]); - - let in_evens = [ - in0002, in0406, in0810, in1214, in1618, in2022, in2426, in2830, - ]; - - // step 2: column FFTs - let evens = self.bf16.perform_fft_direct(in_evens); - let mut odds1 = self + unsafe fn perform_fft_contiguous(&self, mut buffer: impl WasmSimdArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the 
transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors + let mut tmp0 = self .bf8 - .perform_fft_direct([in0105, in0913, in1721, in2529]); - let mut odds3 = self + .bf4 + .perform_parallel_fft_direct(Self::load_chunk(&buffer, 0)); + tmp0[1] = mul_complex_f32(tmp0[1], self.twiddles_packed[0]); + tmp0[2] = mul_complex_f32(tmp0[2], self.twiddles_packed[1]); + tmp0[3] = mul_complex_f32(tmp0[3], self.twiddles_packed[2]); + let [mid0, mid1] = transpose_complex_2x2_f32(tmp0[0], tmp0[1]); + let [mid8, mid9] = transpose_complex_2x2_f32(tmp0[2], tmp0[3]); + + let mut tmp1 = self .bf8 - .perform_fft_direct([in3103, in0711, in1519, in2327]); - - // step 3: apply twiddle factors - odds1[0] = mul_complex_f32(odds1[0], self.twiddle01); - odds3[0] = mul_complex_f32(odds3[0], self.twiddle01conj); - - odds1[1] = mul_complex_f32(odds1[1], self.twiddle23); - odds3[1] = mul_complex_f32(odds3[1], self.twiddle23conj); - - odds1[2] = mul_complex_f32(odds1[2], self.twiddle45); - odds3[2] = mul_complex_f32(odds3[2], self.twiddle45conj); - - odds1[3] = mul_complex_f32(odds1[3], self.twiddle67); - odds3[3] = mul_complex_f32(odds3[3], self.twiddle67conj); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); + .bf4 + .perform_parallel_fft_direct(Self::load_chunk(&buffer, 2)); + tmp1[1] = mul_complex_f32(tmp1[1], self.twiddles_packed[3]); + tmp1[2] = mul_complex_f32(tmp1[2], self.twiddles_packed[4]); + tmp1[3] = mul_complex_f32(tmp1[3], self.twiddles_packed[5]); + let [mid2, mid3] = transpose_complex_2x2_f32(tmp1[0], tmp1[1]); + let [mid10, mid11] = transpose_complex_2x2_f32(tmp1[2], tmp1[3]); + + let mut tmp2 = self + .bf8 + .bf4 + .perform_parallel_fft_direct(Self::load_chunk(&buffer, 4)); + tmp2[1] = mul_complex_f32(tmp2[1], self.twiddles_packed[6]); + tmp2[2] = mul_complex_f32(tmp2[2], self.twiddles_packed[7]); + tmp2[3] = mul_complex_f32(tmp2[3], self.twiddles_packed[8]); + let [mid4, mid5] = transpose_complex_2x2_f32(tmp2[0], tmp2[1]); + let [mid12, mid13] = transpose_complex_2x2_f32(tmp2[2], tmp2[3]); + + let mut tmp3 = self + .bf8 + .bf4 + .perform_parallel_fft_direct(Self::load_chunk(&buffer, 6)); + tmp3[1] = mul_complex_f32(tmp3[1], self.twiddles_packed[9]); + tmp3[2] = mul_complex_f32(tmp3[2], self.twiddles_packed[10]); + tmp3[3] = mul_complex_f32(tmp3[3], self.twiddles_packed[11]); + let [mid6, mid7] = transpose_complex_2x2_f32(tmp3[0], tmp3[1]); + let [mid14, mid15] = transpose_complex_2x2_f32(tmp3[2], tmp3[3]); + + // Size-8 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self + .bf8 + .perform_parallel_fft_direct([mid0, mid1, mid2, mid3, mid4, mid5, mid6, mid7]); + Self::store_chunk(&mut buffer, 0, out0); - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); + let out1 = self + .bf8 + .perform_parallel_fft_direct([mid8, mid9, mid10, mid11, mid12, mid13, mid14, mid15]); + Self::store_chunk(&mut buffer, 2, out1); + 
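For reference, the twiddles_packed tables used above follow the standard mixed-radix rule: when a length-N FFT is split into R rows by C columns, the element at (row r, column c) after the column FFTs is scaled by w^(r*c), with w the primitive N-th root of unity for the chosen direction (w = e^(-2*pi*i/N) in the forward case). Because the f32 path works on two columns per v128, each packed vector pairs the exponents for columns c and c+1 of one row, and row 0 (always w^0 = 1) is skipped. A minimal standalone sketch, not part of this diff and using illustrative names (len, rows, cols), that reproduces the exponents packed for the size-32 case above:

// Sketch only: print the twiddle exponents that a 4x8 decomposition of a
// size-32 FFT needs, grouped the same way as the packed f32 twiddles
// (two adjacent columns per vector, row 0 omitted because w^0 = 1).
fn main() {
    let len = 32; // total FFT size
    let rows = 4; // length of the column FFTs
    let cols = 8; // number of columns

    for c in (0..cols).step_by(2) {
        for r in 1..rows {
            println!(
                "column pair ({}, {}), row {}: (w^{}, w^{})",
                c,
                c + 1,
                r,
                (r * c) % len,
                (r * (c + 1)) % len
            );
        }
    }
}

Running this for len = 32 yields (w^0, w^1), (w^0, w^2), (w^0, w^3), (w^2, w^3), (w^4, w^6), (w^6, w^9), (w^4, w^5), (w^8, w^10), (w^12, w^15), (w^6, w^7), (w^12, w^14), (w^18, w^21), matching the twelve twiddles_packed entries; the size-16 and size-24 constructors follow the same rule with rows = 4 and cols = 4 or 6.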
} - //step 5: copy/add/subtract data back to buffer - [ - f32x4_add(evens[0], temp0[0]), - f32x4_add(evens[1], temp1[0]), - f32x4_add(evens[2], temp2[0]), - f32x4_add(evens[3], temp3[0]), - f32x4_add(evens[4], temp0[1]), - f32x4_add(evens[5], temp1[1]), - f32x4_add(evens[6], temp2[1]), - f32x4_add(evens[7], temp3[1]), - f32x4_sub(evens[0], temp0[0]), - f32x4_sub(evens[1], temp1[0]), - f32x4_sub(evens[2], temp2[0]), - f32x4_sub(evens[3], temp3[0]), - f32x4_sub(evens[4], temp0[1]), - f32x4_sub(evens[5], temp1[1]), - f32x4_sub(evens[6], temp2[1]), - f32x4_sub(evens[7], temp3[1]), - ] + #[inline(always)] + unsafe fn load_parallel_chunk(buffer: &impl WasmSimdArrayMut, i: usize) -> [[v128; 4]; 2] { + let [a0, a1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 0).0, buffer.load_complex(i + 32).0); + let [b0, b1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 8).0, buffer.load_complex(i + 40).0); + let [c0, c1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 16).0, buffer.load_complex(i + 48).0); + let [d0, d1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 24).0, buffer.load_complex(i + 56).0); + [[a0, b0, c0, d0], [a1, b1, c1, d1]] } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_direct(&self, input: [v128; 32]) -> [v128; 32] { - // we're going to hardcode a step of split radix + unsafe fn store_parallel_chunk( + buffer: &mut impl WasmSimdArrayMut, + i: usize, + values_a: [v128; 8], + values_b: [v128; 8], + ) { + for n in 0..8 { + let [a, b] = transpose_complex_2x2_f32(values_a[n], values_b[n]); + buffer.store_complex(WasmVector32(a), i + n * 4); + buffer.store_complex(WasmVector32(b), i + n * 4 + 32); + } + } - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf16.perform_parallel_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], input[24], input[26], input[28], input[30], - ]); - let mut odds1 = self.bf8.perform_parallel_fft_direct([ - input[1], input[5], input[9], input[13], input[17], input[21], input[25], input[29], + #[inline(always)] + pub(crate) unsafe fn perform_parallel_fft_contiguous( + &self, + mut buffer: impl WasmSimdArrayMut, + ) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors + let [in0, in1] = Self::load_parallel_chunk(&buffer, 0); + let tmp0 = self.bf8.bf4.perform_parallel_fft_direct(in0); + let mut tmp1 = self.bf8.bf4.perform_parallel_fft_direct(in1); + tmp1[1] = mul_complex_f32(tmp1[1], self.twiddle1); + tmp1[2] = mul_complex_f32(tmp1[2], self.twiddle2); + tmp1[3] = mul_complex_f32(tmp1[3], self.twiddle3); + + let [in2, in3] = Self::load_parallel_chunk(&buffer, 2); + let mut tmp2 = 
self.bf8.bf4.perform_parallel_fft_direct(in2); + let mut tmp3 = self.bf8.bf4.perform_parallel_fft_direct(in3); + tmp2[1] = mul_complex_f32(tmp2[1], self.twiddle2); + tmp2[2] = self.bf8.bf4.rotate.rotate_both_45(tmp2[2]); + tmp2[3] = mul_complex_f32(tmp2[3], self.twiddle6); + tmp3[1] = mul_complex_f32(tmp3[1], self.twiddle3); + tmp3[2] = mul_complex_f32(tmp3[2], self.twiddle6); + tmp3[3] = mul_complex_f32(tmp3[3], self.twiddle9); + + let [in4, in5] = Self::load_parallel_chunk(&buffer, 4); + let mut tmp4 = self.bf8.bf4.perform_parallel_fft_direct(in4); + let mut tmp5 = self.bf8.bf4.perform_parallel_fft_direct(in5); + tmp4[1] = self.bf8.bf4.rotate.rotate_both_45(tmp4[1]); + tmp4[2] = self.bf8.bf4.rotate.rotate_both(tmp4[2]); + tmp4[3] = self.bf8.bf4.rotate.rotate_both_135(tmp4[3]); + tmp5[1] = mul_complex_f32(tmp5[1], self.twiddle5); + tmp5[2] = mul_complex_f32(tmp5[2], self.twiddle10); + tmp5[3] = mul_complex_f32(tmp5[3], self.twiddle15); + + let [in6, in7] = Self::load_parallel_chunk(&buffer, 6); + let mut tmp6 = self.bf8.bf4.perform_parallel_fft_direct(in6); + let mut tmp7 = self.bf8.bf4.perform_parallel_fft_direct(in7); + tmp6[1] = mul_complex_f32(tmp6[1], self.twiddle6); + tmp6[2] = self.bf8.bf4.rotate.rotate_both_135(tmp6[2]); + tmp6[3] = mul_complex_f32(tmp6[3], self.twiddle18); + tmp7[1] = mul_complex_f32(tmp7[1], self.twiddle7); + tmp7[2] = mul_complex_f32(tmp7[2], self.twiddle14); + tmp7[3] = mul_complex_f32(tmp7[3], self.twiddle21); + + // Size-8 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self.bf8.perform_parallel_fft_direct([ + tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0], ]); - let mut odds3 = self.bf8.perform_parallel_fft_direct([ - input[31], input[3], input[7], input[11], input[15], input[19], input[23], input[27], + let out1 = self.bf8.perform_parallel_fft_direct([ + tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1], tmp6[1], tmp7[1], ]); + Self::store_parallel_chunk(&mut buffer, 0, out0, out1); - // step 3: apply twiddle factors - odds1[1] = mul_complex_f32(odds1[1], self.twiddle1); - odds3[1] = mul_complex_f32(odds3[1], self.twiddle1c); - - odds1[2] = mul_complex_f32(odds1[2], self.twiddle2); - odds3[2] = mul_complex_f32(odds3[2], self.twiddle2c); - - odds1[3] = mul_complex_f32(odds1[3], self.twiddle3); - odds3[3] = mul_complex_f32(odds3[3], self.twiddle3c); - - odds1[4] = mul_complex_f32(odds1[4], self.twiddle4); - odds3[4] = mul_complex_f32(odds3[4], self.twiddle4c); - - odds1[5] = mul_complex_f32(odds1[5], self.twiddle5); - odds3[5] = mul_complex_f32(odds3[5], self.twiddle5c); - - odds1[6] = mul_complex_f32(odds1[6], self.twiddle6); - odds3[6] = mul_complex_f32(odds3[6], self.twiddle6c); - - odds1[7] = mul_complex_f32(odds1[7], self.twiddle7); - odds3[7] = mul_complex_f32(odds3[7], self.twiddle7c); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); - let mut temp4 = parallel_fft2_interleaved_f32(odds1[4], odds3[4]); - let mut temp5 = parallel_fft2_interleaved_f32(odds1[5], odds3[5]); - let mut temp6 = parallel_fft2_interleaved_f32(odds1[6], odds3[6]); - let mut temp7 = parallel_fft2_interleaved_f32(odds1[7], odds3[7]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - 
temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); - temp4[1] = self.rotate90.rotate_both(temp4[1]); - temp5[1] = self.rotate90.rotate_both(temp5[1]); - temp6[1] = self.rotate90.rotate_both(temp6[1]); - temp7[1] = self.rotate90.rotate_both(temp7[1]); - - //step 5: copy/add/subtract data back to buffer - [ - f32x4_add(evens[0], temp0[0]), - f32x4_add(evens[1], temp1[0]), - f32x4_add(evens[2], temp2[0]), - f32x4_add(evens[3], temp3[0]), - f32x4_add(evens[4], temp4[0]), - f32x4_add(evens[5], temp5[0]), - f32x4_add(evens[6], temp6[0]), - f32x4_add(evens[7], temp7[0]), - f32x4_add(evens[8], temp0[1]), - f32x4_add(evens[9], temp1[1]), - f32x4_add(evens[10], temp2[1]), - f32x4_add(evens[11], temp3[1]), - f32x4_add(evens[12], temp4[1]), - f32x4_add(evens[13], temp5[1]), - f32x4_add(evens[14], temp6[1]), - f32x4_add(evens[15], temp7[1]), - f32x4_sub(evens[0], temp0[0]), - f32x4_sub(evens[1], temp1[0]), - f32x4_sub(evens[2], temp2[0]), - f32x4_sub(evens[3], temp3[0]), - f32x4_sub(evens[4], temp4[0]), - f32x4_sub(evens[5], temp5[0]), - f32x4_sub(evens[6], temp6[0]), - f32x4_sub(evens[7], temp7[0]), - f32x4_sub(evens[8], temp0[1]), - f32x4_sub(evens[9], temp1[1]), - f32x4_sub(evens[10], temp2[1]), - f32x4_sub(evens[11], temp3[1]), - f32x4_sub(evens[12], temp4[1]), - f32x4_sub(evens[13], temp5[1]), - f32x4_sub(evens[14], temp6[1]), - f32x4_sub(evens[15], temp7[1]), - ] + let out2 = self.bf8.perform_parallel_fft_direct([ + tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2], tmp6[2], tmp7[2], + ]); + let out3 = self.bf8.perform_parallel_fft_direct([ + tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3], tmp6[3], tmp7[3], + ]); + Self::store_parallel_chunk(&mut buffer, 2, out2, out3); } } @@ -3550,251 +3344,156 @@ impl WasmSimdF32Butterfly32 { // pub struct WasmSimdF64Butterfly32 { - direction: FftDirection, bf8: WasmSimdF64Butterfly8, - bf16: WasmSimdF64Butterfly16, - rotate90: Rotate90F64, twiddle1: v128, twiddle2: v128, twiddle3: v128, - twiddle4: v128, twiddle5: v128, twiddle6: v128, twiddle7: v128, - twiddle1c: v128, - twiddle2c: v128, - twiddle3c: v128, - twiddle4c: v128, - twiddle5c: v128, - twiddle6c: v128, - twiddle7c: v128, + twiddle9: v128, + twiddle10: v128, + twiddle14: v128, + twiddle15: v128, + twiddle18: v128, + twiddle21: v128, } boilerplate_fft_wasm_simd_f64_butterfly!( WasmSimdF64Butterfly32, 32, - |this: &WasmSimdF64Butterfly32<_>| this.direction + |this: &WasmSimdF64Butterfly32<_>| this.bf8.bf4.direction ); boilerplate_fft_wasm_simd_common_butterfly!( WasmSimdF64Butterfly32, 32, - |this: &WasmSimdF64Butterfly32<_>| this.direction + |this: &WasmSimdF64Butterfly32<_>| this.bf8.bf4.direction ); impl WasmSimdF64Butterfly32 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f64::(); - let bf8 = WasmSimdF64Butterfly8::new(direction); - let bf16 = WasmSimdF64Butterfly16::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F64::new(true) - } else { - Rotate90F64::new(false) - }; - let twiddle1 = unsafe { - v128_load( - &twiddles::compute_twiddle::(1, 32, direction) as *const _ as *const v128, - ) - }; - let twiddle2 = unsafe { - v128_load( - &twiddles::compute_twiddle::(2, 32, direction) as *const _ as *const v128, - ) - }; - let twiddle3 = unsafe { - v128_load( - &twiddles::compute_twiddle::(3, 32, direction) as *const _ as *const v128, - ) - }; - let twiddle4 = unsafe { - v128_load( - &twiddles::compute_twiddle::(4, 32, direction) as *const 
_ as *const v128, - ) - }; - let twiddle5 = unsafe { - v128_load( - &twiddles::compute_twiddle::(5, 32, direction) as *const _ as *const v128, - ) - }; - let twiddle6 = unsafe { - v128_load( - &twiddles::compute_twiddle::(6, 32, direction) as *const _ as *const v128, - ) - }; - let twiddle7 = unsafe { - v128_load( - &twiddles::compute_twiddle::(7, 32, direction) as *const _ as *const v128, - ) - }; - let twiddle1c = unsafe { - v128_load( - &twiddles::compute_twiddle::(1, 32, direction).conj() as *const _ - as *const v128, - ) - }; - let twiddle2c = unsafe { - v128_load( - &twiddles::compute_twiddle::(2, 32, direction).conj() as *const _ - as *const v128, - ) - }; - let twiddle3c = unsafe { - v128_load( - &twiddles::compute_twiddle::(3, 32, direction).conj() as *const _ - as *const v128, - ) - }; - let twiddle4c = unsafe { - v128_load( - &twiddles::compute_twiddle::(4, 32, direction).conj() as *const _ - as *const v128, - ) - }; - let twiddle5c = unsafe { - v128_load( - &twiddles::compute_twiddle::(5, 32, direction).conj() as *const _ - as *const v128, - ) - }; - let twiddle6c = unsafe { - v128_load( - &twiddles::compute_twiddle::(6, 32, direction).conj() as *const _ - as *const v128, - ) - }; - let twiddle7c = unsafe { - v128_load( - &twiddles::compute_twiddle::(7, 32, direction).conj() as *const _ - as *const v128, - ) - }; + let tw1: Complex = twiddles::compute_twiddle(1, 32, direction); + let tw2: Complex = twiddles::compute_twiddle(2, 32, direction); + let tw3: Complex = twiddles::compute_twiddle(3, 32, direction); + let tw5: Complex = twiddles::compute_twiddle(5, 32, direction); + let tw6: Complex = twiddles::compute_twiddle(6, 32, direction); + let tw7: Complex = twiddles::compute_twiddle(7, 32, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 32, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 32, direction); + let tw14: Complex = twiddles::compute_twiddle(14, 32, direction); + let tw15: Complex = twiddles::compute_twiddle(15, 32, direction); + let tw18: Complex = twiddles::compute_twiddle(18, 32, direction); + let tw21: Complex = twiddles::compute_twiddle(21, 32, direction); - Self { - direction, - bf8, - bf16, - rotate90, - twiddle1, - twiddle2, - twiddle3, - twiddle4, - twiddle5, - twiddle6, - twiddle7, - twiddle1c, - twiddle2c, - twiddle3c, - twiddle4c, - twiddle5c, - twiddle6c, - twiddle7c, + unsafe { + Self { + bf8: WasmSimdF64Butterfly8::new(direction), + twiddle1: pack_64(tw1), + twiddle2: pack_64(tw2), + twiddle3: pack_64(tw3), + twiddle5: pack_64(tw5), + twiddle6: pack_64(tw6), + twiddle7: pack_64(tw7), + twiddle9: pack_64(tw9), + twiddle10: pack_64(tw10), + twiddle14: pack_64(tw14), + twiddle15: pack_64(tw15), + twiddle18: pack_64(tw18), + twiddle21: pack_64(tw21), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl WasmSimdArrayMut) { - let values = read_complex_to_array!(buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); - - let out = self.perform_fft_direct(values); - - write_complex_to_array!(out, buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); - } + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs + // But to reduce the number of times registers get 
spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i| { + [ + buffer.load_complex(i).0, + buffer.load_complex(i + 8).0, + buffer.load_complex(i + 16).0, + buffer.load_complex(i + 24).0, + ] + }; - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [v128; 32]) -> [v128; 32] { - // we're going to hardcode a step of split radix + // For each column: load the data, apply our size-4 FFT, apply twiddle factors + let mut tmp1 = self.bf8.bf4.perform_fft_direct(load(1)); + tmp1[1] = mul_complex_f64(tmp1[1], self.twiddle1); + tmp1[2] = mul_complex_f64(tmp1[2], self.twiddle2); + tmp1[3] = mul_complex_f64(tmp1[3], self.twiddle3); + + let mut tmp2 = self.bf8.bf4.perform_fft_direct(load(2)); + tmp2[1] = mul_complex_f64(tmp2[1], self.twiddle2); + tmp2[2] = self.bf8.bf4.rotate.rotate_45(tmp2[2]); + tmp2[3] = mul_complex_f64(tmp2[3], self.twiddle6); + + let mut tmp3 = self.bf8.bf4.perform_fft_direct(load(3)); + tmp3[1] = mul_complex_f64(tmp3[1], self.twiddle3); + tmp3[2] = mul_complex_f64(tmp3[2], self.twiddle6); + tmp3[3] = mul_complex_f64(tmp3[3], self.twiddle9); + + let mut tmp5 = self.bf8.bf4.perform_fft_direct(load(5)); + tmp5[1] = mul_complex_f64(tmp5[1], self.twiddle5); + tmp5[2] = mul_complex_f64(tmp5[2], self.twiddle10); + tmp5[3] = mul_complex_f64(tmp5[3], self.twiddle15); + + let mut tmp6 = self.bf8.bf4.perform_fft_direct(load(6)); + tmp6[1] = mul_complex_f64(tmp6[1], self.twiddle6); + tmp6[2] = self.bf8.bf4.rotate.rotate_135(tmp6[2]); + tmp6[3] = mul_complex_f64(tmp6[3], self.twiddle18); + + let mut tmp7 = self.bf8.bf4.perform_fft_direct(load(7)); + tmp7[1] = mul_complex_f64(tmp7[1], self.twiddle7); + tmp7[2] = mul_complex_f64(tmp7[2], self.twiddle14); + tmp7[3] = mul_complex_f64(tmp7[3], self.twiddle21); + + let mut tmp4 = self.bf8.bf4.perform_fft_direct(load(4)); + tmp4[1] = self.bf8.bf4.rotate.rotate_45(tmp4[1]); + tmp4[2] = self.bf8.bf4.rotate.rotate(tmp4[2]); + tmp4[3] = self.bf8.bf4.rotate.rotate_135(tmp4[3]); + + // Do the first column last, because no twiddles means fewer temporaries forcing the above data to spill + let tmp0 = self.bf8.bf4.perform_fft_direct(load(0)); + + //////////////////////////////////////////////////////////// + let mut store = |i, vectors: [v128; 8]| { + buffer.store_complex(WasmVector64(vectors[0]), i); + buffer.store_complex(WasmVector64(vectors[1]), i + 4); + buffer.store_complex(WasmVector64(vectors[2]), i + 8); + buffer.store_complex(WasmVector64(vectors[3]), i + 12); + buffer.store_complex(WasmVector64(vectors[4]), i + 16); + buffer.store_complex(WasmVector64(vectors[5]), i + 20); + buffer.store_complex(WasmVector64(vectors[6]), i + 24); + buffer.store_complex(WasmVector64(vectors[7]), i + 28); + }; - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf16.perform_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], input[24], input[26], input[28], input[30], + // Size-8 FFTs down each of our transposed columns, storing them as soon as we're done with them + let out0 = 
self.bf8.perform_fft_direct([ + tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0], ]); - let mut odds1 = self.bf8.perform_fft_direct([ - input[1], input[5], input[9], input[13], input[17], input[21], input[25], input[29], + store(0, out0); + + let out1 = self.bf8.perform_fft_direct([ + tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1], tmp6[1], tmp7[1], ]); - let mut odds3 = self.bf8.perform_fft_direct([ - input[31], input[3], input[7], input[11], input[15], input[19], input[23], input[27], + store(1, out1); + + let out2 = self.bf8.perform_fft_direct([ + tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2], tmp6[2], tmp7[2], ]); + store(2, out2); - // step 3: apply twiddle factors - odds1[1] = mul_complex_f64(odds1[1], self.twiddle1); - odds3[1] = mul_complex_f64(odds3[1], self.twiddle1c); - - odds1[2] = mul_complex_f64(odds1[2], self.twiddle2); - odds3[2] = mul_complex_f64(odds3[2], self.twiddle2c); - - odds1[3] = mul_complex_f64(odds1[3], self.twiddle3); - odds3[3] = mul_complex_f64(odds3[3], self.twiddle3c); - - odds1[4] = mul_complex_f64(odds1[4], self.twiddle4); - odds3[4] = mul_complex_f64(odds3[4], self.twiddle4c); - - odds1[5] = mul_complex_f64(odds1[5], self.twiddle5); - odds3[5] = mul_complex_f64(odds3[5], self.twiddle5c); - - odds1[6] = mul_complex_f64(odds1[6], self.twiddle6); - odds3[6] = mul_complex_f64(odds3[6], self.twiddle6c); - - odds1[7] = mul_complex_f64(odds1[7], self.twiddle7); - odds3[7] = mul_complex_f64(odds3[7], self.twiddle7c); - - // step 4: cross FFTs - let mut temp0 = solo_fft2_f64(odds1[0], odds3[0]); - let mut temp1 = solo_fft2_f64(odds1[1], odds3[1]); - let mut temp2 = solo_fft2_f64(odds1[2], odds3[2]); - let mut temp3 = solo_fft2_f64(odds1[3], odds3[3]); - let mut temp4 = solo_fft2_f64(odds1[4], odds3[4]); - let mut temp5 = solo_fft2_f64(odds1[5], odds3[5]); - let mut temp6 = solo_fft2_f64(odds1[6], odds3[6]); - let mut temp7 = solo_fft2_f64(odds1[7], odds3[7]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate(temp0[1]); - temp1[1] = self.rotate90.rotate(temp1[1]); - temp2[1] = self.rotate90.rotate(temp2[1]); - temp3[1] = self.rotate90.rotate(temp3[1]); - temp4[1] = self.rotate90.rotate(temp4[1]); - temp5[1] = self.rotate90.rotate(temp5[1]); - temp6[1] = self.rotate90.rotate(temp6[1]); - temp7[1] = self.rotate90.rotate(temp7[1]); - - //step 5: copy/add/subtract data back to buffer - [ - f64x2_add(evens[0], temp0[0]), - f64x2_add(evens[1], temp1[0]), - f64x2_add(evens[2], temp2[0]), - f64x2_add(evens[3], temp3[0]), - f64x2_add(evens[4], temp4[0]), - f64x2_add(evens[5], temp5[0]), - f64x2_add(evens[6], temp6[0]), - f64x2_add(evens[7], temp7[0]), - f64x2_add(evens[8], temp0[1]), - f64x2_add(evens[9], temp1[1]), - f64x2_add(evens[10], temp2[1]), - f64x2_add(evens[11], temp3[1]), - f64x2_add(evens[12], temp4[1]), - f64x2_add(evens[13], temp5[1]), - f64x2_add(evens[14], temp6[1]), - f64x2_add(evens[15], temp7[1]), - f64x2_sub(evens[0], temp0[0]), - f64x2_sub(evens[1], temp1[0]), - f64x2_sub(evens[2], temp2[0]), - f64x2_sub(evens[3], temp3[0]), - f64x2_sub(evens[4], temp4[0]), - f64x2_sub(evens[5], temp5[0]), - f64x2_sub(evens[6], temp6[0]), - f64x2_sub(evens[7], temp7[0]), - f64x2_sub(evens[8], temp0[1]), - f64x2_sub(evens[9], temp1[1]), - f64x2_sub(evens[10], temp2[1]), - f64x2_sub(evens[11], temp3[1]), - f64x2_sub(evens[12], temp4[1]), - f64x2_sub(evens[13], temp5[1]), - f64x2_sub(evens[14], temp6[1]), - f64x2_sub(evens[15], temp7[1]), - ] + let out3 = self.bf8.perform_fft_direct([ 
+ tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3], tmp6[3], tmp7[3], + ]); + store(3, out3); } } diff --git a/src/wasm_simd/wasm_simd_utils.rs b/src/wasm_simd/wasm_simd_utils.rs index 93a8eb9f..555c6ad7 100644 --- a/src/wasm_simd/wasm_simd_utils.rs +++ b/src/wasm_simd/wasm_simd_utils.rs @@ -37,6 +37,27 @@ impl Rotate90F32 { pub unsafe fn rotate_both(&self, values: v128) -> v128 { v128_xor(u32x4_shuffle::<1, 0, 3, 2>(values, values), self.sign_both) } + + #[inline(always)] + pub unsafe fn rotate_both_45(&self, values: v128) -> v128 { + let rotated = self.rotate_both(values); + let sum = f32x4_add(rotated, values); + f32x4_mul(sum, f32x4_splat(0.5f32.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_both_135(&self, values: v128) -> v128 { + let rotated = self.rotate_both(values); + let diff = f32x4_sub(rotated, values); + f32x4_mul(diff, f32x4_splat(0.5f32.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_both_225(&self, values: v128) -> v128 { + let rotated = self.rotate_both(values); + let diff = f32x4_add(rotated, values); + f32x4_mul(diff, f32x4_splat(-(0.5f32.sqrt()))) + } } /// Pack low (1st) complex @@ -165,6 +186,27 @@ impl Rotate90F64 { pub unsafe fn rotate(&self, values: v128) -> v128 { v128_xor(u64x2_shuffle::<1, 0>(values, values), self.sign) } + + #[inline(always)] + pub unsafe fn rotate_45(&self, values: v128) -> v128 { + let rotated = self.rotate(values); + let sum = f64x2_add(rotated, values); + f64x2_mul(sum, f64x2_splat(0.5f64.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_135(&self, values: v128) -> v128 { + let rotated = self.rotate(values); + let diff = f64x2_sub(rotated, values); + f64x2_mul(diff, f64x2_splat(0.5f64.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_225(&self, values: v128) -> v128 { + let rotated = self.rotate(values); + let diff = f64x2_add(rotated, values); + f64x2_mul(diff, f64x2_splat(-(0.5f64.sqrt()))) + } } #[inline(always)] diff --git a/src/wasm_simd/wasm_simd_vector.rs b/src/wasm_simd/wasm_simd_vector.rs index af25ea76..bb1ead89 100644 --- a/src/wasm_simd/wasm_simd_vector.rs +++ b/src/wasm_simd/wasm_simd_vector.rs @@ -158,6 +158,9 @@ pub trait WasmVector: Copy + Debug + Send + Sync { unsafe fn store_partial_lo_complex(ptr: *mut Complex, data: Self); unsafe fn store_partial_hi_complex(ptr: *mut Complex, data: Self); + // math ops + unsafe fn neg(a: Self) -> Self; + /// Generates a chunk of twiddle factors starting at (X,Y) and incrementing X `COMPLEX_PER_VECTOR` times. /// The result will be [twiddle(x*y, len), twiddle((x+1)*y, len), twiddle((x+2)*y, len), ...] 
for as many complex numbers fit in a vector unsafe fn make_mixedradix_twiddle_chunk( @@ -222,6 +225,11 @@ impl WasmVector for WasmVector32 { v128_store64_lane::<1>(data.0, ptr as *mut u64); } + #[inline(always)] + unsafe fn neg(a: Self) -> Self { + Self(f32x4_neg(a.0)) + } + #[inline(always)] unsafe fn make_mixedradix_twiddle_chunk( x: usize, @@ -332,6 +340,11 @@ impl WasmVector for WasmVector64 { unimplemented!("Impossible to do a partial store of complex f64's"); } + #[inline(always)] + unsafe fn neg(a: Self) -> Self { + Self(f64x2_neg(a.0)) + } + #[inline(always)] unsafe fn make_mixedradix_twiddle_chunk( x: usize, From 7c0e11ee0e35bc96bfeb23092db1b1478fe3dc50 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Sun, 18 Feb 2024 22:28:26 -0800 Subject: [PATCH 06/13] cargo fmt --- src/sse/sse_butterflies.rs | 335 ++++++++++++++++++++++++------------- src/sse/sse_utils.rs | 2 +- 2 files changed, 218 insertions(+), 119 deletions(-) diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index 72dadc31..64128c5b 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -23,7 +23,6 @@ unsafe fn pack_64(a: Complex) -> __m128d { _mm_set_pd(a.im, a.re) } - #[allow(unused)] macro_rules! boilerplate_fft_sse_f32_butterfly { ($struct_name:ident, $len:expr, $direction_fn:expr) => { @@ -1437,9 +1436,11 @@ pub struct SseF64Butterfly8 { } boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly8, 8, |this: &SseF64Butterfly8<_>| this - .bf4.direction); + .bf4 + .direction); boilerplate_fft_sse_common_butterfly!(SseF64Butterfly8, 8, |this: &SseF64Butterfly8<_>| this - .bf4.direction); + .bf4 + .direction); impl SseF64Butterfly8 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { @@ -2293,12 +2294,11 @@ pub struct SseF32Butterfly16 { twiddle9: __m128, } -boilerplate_fft_sse_f32_butterfly_noparallel!( - SseF32Butterfly16, - 16, - |this: &SseF32Butterfly16<_>| this.bf4.direction -); -boilerplate_fft_sse_common_butterfly!(SseF32Butterfly16, 16, |this: &SseF32Butterfly16<_>| this.bf4 +boilerplate_fft_sse_f32_butterfly_noparallel!(SseF32Butterfly16, 16, |this: &SseF32Butterfly16< + _, +>| this.bf4.direction); +boilerplate_fft_sse_common_butterfly!(SseF32Butterfly16, 16, |this: &SseF32Butterfly16<_>| this + .bf4 .direction); impl SseF32Butterfly16 { pub fn new(direction: FftDirection) -> Self { @@ -2310,7 +2310,7 @@ impl SseF32Butterfly16 { let tw4: Complex = twiddles::compute_twiddle(4, 16, direction); let tw6: Complex = twiddles::compute_twiddle(6, 16, direction); let tw9: Complex = twiddles::compute_twiddle(9, 16, direction); - + unsafe { Self { bf4: SseF32Butterfly4::new(direction), @@ -2337,15 +2337,17 @@ impl SseF32Butterfly16 { // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again // But to reduce the number of times registers get spilled, we have these optimizations: // 1: Load data as late as possible, not upfront - // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // 2: Once we're working with a piece of data, make as much progress as possible before moving on // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column // 3: Store data as soon as we're finished with it, rather than waiting for the end - let load = |i| [ - buffer.load_complex(i), - buffer.load_complex(i + 4), - 
buffer.load_complex(i + 8), - buffer.load_complex(i + 12), - ]; + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 4), + buffer.load_complex(i + 8), + buffer.load_complex(i + 12), + ] + }; // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors, and transpose let mut tmp0 = self.bf4.perform_parallel_fft_direct(load(0)); @@ -2370,10 +2372,14 @@ impl SseF32Butterfly16 { buffer.store_complex(vectors[3], i + 12); }; // Size-4 FFTs down each pair of transposed columns, storing them as soon as we're done with them - let out0 = self.bf4.perform_parallel_fft_direct([mid0, mid1, mid2, mid3]); + let out0 = self + .bf4 + .perform_parallel_fft_direct([mid0, mid1, mid2, mid3]); store(0, out0); - let out1 = self.bf4.perform_parallel_fft_direct([mid4, mid5, mid6, mid7]); + let out1 = self + .bf4 + .perform_parallel_fft_direct([mid4, mid5, mid6, mid7]); store(2, out1); } @@ -2384,14 +2390,18 @@ impl SseF32Butterfly16 { // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again // But to reduce the number of times registers get spilled, we have these optimizations: // 1: Load data as late as possible, not upfront - // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // 2: Once we're working with a piece of data, make as much progress as possible before moving on // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i: usize| { - let [a0, a1] = transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 16)); - let [b0, b1] = transpose_complex_2x2_f32(buffer.load_complex(i + 4), buffer.load_complex(i + 20)); - let [c0, c1] = transpose_complex_2x2_f32(buffer.load_complex(i + 8), buffer.load_complex(i + 24)); - let [d0, d1] = transpose_complex_2x2_f32(buffer.load_complex(i + 12), buffer.load_complex(i + 28)); + let [a0, a1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 16)); + let [b0, b1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 4), buffer.load_complex(i + 20)); + let [c0, c1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 8), buffer.load_complex(i + 24)); + let [d0, d1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 12), buffer.load_complex(i + 28)); [[a0, b0, c0, d0], [a1, b1, c1, d1]] }; @@ -2418,17 +2428,25 @@ impl SseF32Butterfly16 { let mut store = |i, values_a: [__m128; 4], values_b: [__m128; 4]| { for n in 0..4 { let [a, b] = transpose_complex_2x2_f32(values_a[n], values_b[n]); - buffer.store_complex(a, i + n*4); - buffer.store_complex(b, i + n*4 + 16); + buffer.store_complex(a, i + n * 4); + buffer.store_complex(b, i + n * 4 + 16); } }; // Size-4 FFTs down each pair of transposed columns, storing them as soon as we're done with them - let out0 = self.bf4.perform_parallel_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); - let out1 = self.bf4.perform_parallel_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); + let out0 = self + .bf4 + .perform_parallel_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); + let out1 = self + .bf4 + .perform_parallel_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); store(0, out0, out1); - let out2 = self.bf4.perform_parallel_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2]]); - 
let out3 = self.bf4.perform_parallel_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3]]); + let out2 = self + .bf4 + .perform_parallel_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2]]); + let out3 = self + .bf4 + .perform_parallel_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3]]); store(2, out2, out3); } } @@ -2447,9 +2465,11 @@ pub struct SseF64Butterfly16 { } boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly16, 16, |this: &SseF64Butterfly16<_>| this - .bf4.direction); + .bf4 + .direction); boilerplate_fft_sse_common_butterfly!(SseF64Butterfly16, 16, |this: &SseF64Butterfly16<_>| this - .bf4.direction); + .bf4 + .direction); impl SseF64Butterfly16 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { @@ -2457,7 +2477,7 @@ impl SseF64Butterfly16 { let tw1: Complex = twiddles::compute_twiddle(1, 16, direction); let tw3: Complex = twiddles::compute_twiddle(3, 16, direction); let tw9: Complex = twiddles::compute_twiddle(9, 16, direction); - + unsafe { Self { bf4: SseF64Butterfly4::new(direction), @@ -2474,15 +2494,17 @@ impl SseF64Butterfly16 { // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again // But to reduce the number of times registers get spilled, we have these optimizations: // 1: Load data as late as possible, not upfront - // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // 2: Once we're working with a piece of data, make as much progress as possible before moving on // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column // 3: Store data as soon as we're finished with it, rather than waiting for the end - let load = |i| [ - buffer.load_complex(i), - buffer.load_complex(i + 4), - buffer.load_complex(i + 8), - buffer.load_complex(i + 12), - ]; + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 4), + buffer.load_complex(i + 8), + buffer.load_complex(i + 12), + ] + }; // For each column: load the data, apply our size-4 FFT, apply twiddle factors let mut tmp1 = self.bf4.perform_fft_direct(load(1)); @@ -2512,16 +2534,24 @@ impl SseF64Butterfly16 { }; // Size-4 FFTs down each of our transposed columns, storing them as soon as we're done with them - let out0 = self.bf4.perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); + let out0 = self + .bf4 + .perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); store(0, out0); - let out1 = self.bf4.perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); + let out1 = self + .bf4 + .perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); store(1, out1); - let out2 = self.bf4.perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2]]); + let out2 = self + .bf4 + .perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2]]); store(2, out2); - let out3 = self.bf4.perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3]]); + let out3 = self + .bf4 + .perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3]]); store(3, out3); } } @@ -2549,7 +2579,8 @@ boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly24, 24, |this: &SseF32Butterfl this.bf4.direction }); boilerplate_fft_sse_common_butterfly!(SseF32Butterfly24, 24, |this: &SseF32Butterfly24<_>| this - .bf4.direction); + .bf4 + .direction); impl SseF32Butterfly24 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { @@ -2597,15 +2628,17 @@ impl SseF32Butterfly24 { // It's 6x4 mixed radix, 
so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs // But to reduce the number of times registers get spilled, we have these optimizations: // 1: Load data as late as possible, not upfront - // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // 2: Once we're working with a piece of data, make as much progress as possible before moving on // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column // 3: Store data as soon as we're finished with it, rather than waiting for the end - let load = |i| [ - buffer.load_complex(i), - buffer.load_complex(i + 6), - buffer.load_complex(i + 12), - buffer.load_complex(i + 18), - ]; + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 6), + buffer.load_complex(i + 12), + buffer.load_complex(i + 18), + ] + }; // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors, transpose let mut tmp1 = self.bf4.perform_parallel_fft_direct(load(2)); @@ -2640,10 +2673,14 @@ impl SseF32Butterfly24 { }; // Size-6 FFTs down each pair of transposed columns, storing them as soon as we're done with them - let out0 = self.bf6.perform_parallel_fft_direct(mid0, mid1, mid2, mid3, mid4, mid5); + let out0 = self + .bf6 + .perform_parallel_fft_direct(mid0, mid1, mid2, mid3, mid4, mid5); store(0, out0); - let out1 = self.bf6.perform_parallel_fft_direct(mid6, mid7, mid8, mid9, mid10, mid11); + let out1 = self + .bf6 + .perform_parallel_fft_direct(mid6, mid7, mid8, mid9, mid10, mid11); store(2, out1); } @@ -2653,14 +2690,18 @@ impl SseF32Butterfly24 { // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs // But to reduce the number of times registers get spilled, we have these optimizations: // 1: Load data as late as possible, not upfront - // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // 2: Once we're working with a piece of data, make as much progress as possible before moving on // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i: usize| { - let [a0, a1] = transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 24)); - let [b0, b1] = transpose_complex_2x2_f32(buffer.load_complex(i + 6), buffer.load_complex(i + 30)); - let [c0, c1] = transpose_complex_2x2_f32(buffer.load_complex(i + 12), buffer.load_complex(i + 36)); - let [d0, d1] = transpose_complex_2x2_f32(buffer.load_complex(i + 18), buffer.load_complex(i + 42)); + let [a0, a1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 24)); + let [b0, b1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 6), buffer.load_complex(i + 30)); + let [c0, c1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 12), buffer.load_complex(i + 36)); + let [d0, d1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 18), buffer.load_complex(i + 42)); [[a0, b0, c0, d0], [a1, b1, c1, d1]] }; @@ -2696,18 +2737,26 @@ impl SseF32Butterfly24 { let mut store = |i, vectors_a: [__m128; 6], vectors_b: [__m128; 6]| { for n in 0..6 { let 
[a, b] = transpose_complex_2x2_f32(vectors_a[n], vectors_b[n]); - buffer.store_complex(a, i + n*4); - buffer.store_complex(b, i + n*4 + 24); + buffer.store_complex(a, i + n * 4); + buffer.store_complex(b, i + n * 4 + 24); } }; // Size-6 FFTs down each pair of transposed columns, storing them as soon as we're done with them - let out0 = self.bf6.perform_parallel_fft_direct(tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]); - let out1 = self.bf6.perform_parallel_fft_direct(tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]); + let out0 = self + .bf6 + .perform_parallel_fft_direct(tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]); + let out1 = self + .bf6 + .perform_parallel_fft_direct(tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]); store(0, out0, out1); - let out2 = self.bf6.perform_parallel_fft_direct(tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2]); - let out3 = self.bf6.perform_parallel_fft_direct(tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3]); + let out2 = self + .bf6 + .perform_parallel_fft_direct(tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2]); + let out3 = self + .bf6 + .perform_parallel_fft_direct(tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3]); store(2, out2, out3); } } @@ -2734,7 +2783,8 @@ boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly24, 24, |this: &SseF64Butterfl this.bf4.direction }); boilerplate_fft_sse_common_butterfly!(SseF64Butterfly24, 24, |this: &SseF64Butterfly24<_>| this - .bf4.direction); + .bf4 + .direction); impl SseF64Butterfly24 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { @@ -2745,7 +2795,7 @@ impl SseF64Butterfly24 { let tw5: Complex = twiddles::compute_twiddle(5, 24, direction); let tw8: Complex = twiddles::compute_twiddle(8, 24, direction); let tw10: Complex = twiddles::compute_twiddle(10, 24, direction); - + unsafe { Self { bf4: SseF64Butterfly4::new(direction), @@ -2766,15 +2816,17 @@ impl SseF64Butterfly24 { // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs // But to reduce the number of times registers get spilled, we have these optimizations: // 1: Load data as late as possible, not upfront - // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // 2: Once we're working with a piece of data, make as much progress as possible before moving on // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column // 3: Store data as soon as we're finished with it, rather than waiting for the end - let load = |i| [ - buffer.load_complex(i), - buffer.load_complex(i + 6), - buffer.load_complex(i + 12), - buffer.load_complex(i + 18), - ]; + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 6), + buffer.load_complex(i + 12), + buffer.load_complex(i + 18), + ] + }; // For each column: load the data, apply our size-4 FFT, apply twiddle factors let mut tmp1 = self.bf4.perform_fft_direct(load(1)); @@ -2786,7 +2838,7 @@ impl SseF64Butterfly24 { tmp2[1] = SseVector::mul_complex(tmp2[1], self.twiddle2); tmp2[2] = SseVector::mul_complex(tmp2[2], self.twiddle4); tmp2[3] = self.bf4.rotate.rotate(tmp2[3]); - + let mut tmp4 = self.bf4.perform_fft_direct(load(4)); tmp4[1] = SseVector::mul_complex(tmp4[1], self.twiddle4); tmp4[2] = SseVector::mul_complex(tmp4[2], self.twiddle8); @@ -2816,16 +2868,24 @@ impl SseF64Butterfly24 { }; // 
Size-6 FFTs down each of our transposed columns, storing them as soon as we're done with them - let out0 = self.bf6.perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]]); + let out0 = self + .bf6 + .perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]]); store(0, out0); - let out1 = self.bf6.perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]]); + let out1 = self + .bf6 + .perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]]); store(1, out1); - let out2 = self.bf6.perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2]]); + let out2 = self + .bf6 + .perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2]]); store(2, out2); - - let out3 = self.bf6.perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3]]); + + let out3 = self + .bf6 + .perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3]]); store(3, out3); } } @@ -2855,9 +2915,13 @@ pub struct SseF32Butterfly32 { } boilerplate_fft_sse_f32_butterfly!(SseF32Butterfly32, 32, |this: &SseF32Butterfly32<_>| this - .bf8.bf4.direction); + .bf8 + .bf4 + .direction); boilerplate_fft_sse_common_butterfly!(SseF32Butterfly32, 32, |this: &SseF32Butterfly32<_>| this - .bf8.bf4.direction); + .bf8 + .bf4 + .direction); impl SseF32Butterfly32 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { @@ -2917,15 +2981,17 @@ impl SseF32Butterfly32 { // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs // But to reduce the number of times registers get spilled, we have these optimizations: // 1: Load data as late as possible, not upfront - // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // 2: Once we're working with a piece of data, make as much progress as possible before moving on // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column // 3: Store data as soon as we're finished with it, rather than waiting for the end - let load = |i| [ - buffer.load_complex(i), - buffer.load_complex(i + 8), - buffer.load_complex(i + 16), - buffer.load_complex(i + 24), - ]; + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 8), + buffer.load_complex(i + 16), + buffer.load_complex(i + 24), + ] + }; // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors let mut tmp0 = self.bf8.bf4.perform_parallel_fft_direct(load(0)); @@ -2969,10 +3035,14 @@ impl SseF32Butterfly32 { }; // Size-8 FFTs down each pair of transposed columns, storing them as soon as we're done with them - let out0 = self.bf8.perform_parallel_fft_direct([mid0, mid1, mid2, mid3, mid4, mid5, mid6, mid7]); + let out0 = self + .bf8 + .perform_parallel_fft_direct([mid0, mid1, mid2, mid3, mid4, mid5, mid6, mid7]); store(0, out0); - let out1 = self.bf8.perform_parallel_fft_direct([mid8, mid9, mid10, mid11, mid12, mid13, mid14, mid15]); + let out1 = self + .bf8 + .perform_parallel_fft_direct([mid8, mid9, mid10, mid11, mid12, mid13, mid14, mid15]); store(2, out1); } @@ -2982,14 +3052,18 @@ impl SseF32Butterfly32 { // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs // But to reduce the number of times registers get spilled, we have these 
optimizations: // 1: Load data as late as possible, not upfront - // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // 2: Once we're working with a piece of data, make as much progress as possible before moving on // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column // 3: Store data as soon as we're finished with it, rather than waiting for the end let load = |i: usize| { - let [a0, a1] = transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 32)); - let [b0, b1] = transpose_complex_2x2_f32(buffer.load_complex(i + 8), buffer.load_complex(i + 40)); - let [c0, c1] = transpose_complex_2x2_f32(buffer.load_complex(i + 16), buffer.load_complex(i + 48)); - let [d0, d1] = transpose_complex_2x2_f32(buffer.load_complex(i + 24), buffer.load_complex(i + 56)); + let [a0, a1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 32)); + let [b0, b1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 8), buffer.load_complex(i + 40)); + let [c0, c1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 16), buffer.load_complex(i + 48)); + let [d0, d1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 24), buffer.load_complex(i + 56)); [[a0, b0, c0, d0], [a1, b1, c1, d1]] }; @@ -3035,18 +3109,26 @@ impl SseF32Butterfly32 { let mut store = |i, vectors_a: [__m128; 8], vectors_b: [__m128; 8]| { for n in 0..8 { let [a, b] = transpose_complex_2x2_f32(vectors_a[n], vectors_b[n]); - buffer.store_complex(a, i + n*4); - buffer.store_complex(b, i + n*4 + 32); + buffer.store_complex(a, i + n * 4); + buffer.store_complex(b, i + n * 4 + 32); } }; // Size-8 FFTs down each pair of transposed columns, storing them as soon as we're done with them - let out0 = self.bf8.perform_parallel_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0]]); - let out1 = self.bf8.perform_parallel_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1], tmp6[1], tmp7[1]]); + let out0 = self.bf8.perform_parallel_fft_direct([ + tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0], + ]); + let out1 = self.bf8.perform_parallel_fft_direct([ + tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1], tmp6[1], tmp7[1], + ]); store(0, out0, out1); - let out2 = self.bf8.perform_parallel_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2], tmp6[2], tmp7[2]]); - let out3 = self.bf8.perform_parallel_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3], tmp6[3], tmp7[3]]); + let out2 = self.bf8.perform_parallel_fft_direct([ + tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2], tmp6[2], tmp7[2], + ]); + let out3 = self.bf8.perform_parallel_fft_direct([ + tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3], tmp6[3], tmp7[3], + ]); store(2, out2, out3); } } @@ -3075,9 +3157,13 @@ pub struct SseF64Butterfly32 { } boilerplate_fft_sse_f64_butterfly!(SseF64Butterfly32, 32, |this: &SseF64Butterfly32<_>| this - .bf8.bf4.direction); + .bf8 + .bf4 + .direction); boilerplate_fft_sse_common_butterfly!(SseF64Butterfly32, 32, |this: &SseF64Butterfly32<_>| this - .bf8.bf4.direction); + .bf8 + .bf4 + .direction); impl SseF64Butterfly32 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { @@ -3094,7 +3180,7 @@ impl SseF64Butterfly32 { let tw15: Complex = twiddles::compute_twiddle(15, 32, direction); let tw18: Complex = twiddles::compute_twiddle(18, 
32, direction); let tw21: Complex = twiddles::compute_twiddle(21, 32, direction); - + unsafe { Self { bf8: SseF64Butterfly8::new(direction), @@ -3120,15 +3206,17 @@ impl SseF64Butterfly32 { // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs // But to reduce the number of times registers get spilled, we have these optimizations: // 1: Load data as late as possible, not upfront - // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // 2: Once we're working with a piece of data, make as much progress as possible before moving on // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column // 3: Store data as soon as we're finished with it, rather than waiting for the end - let load = |i| [ - buffer.load_complex(i), - buffer.load_complex(i + 8), - buffer.load_complex(i + 16), - buffer.load_complex(i + 24), - ]; + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 8), + buffer.load_complex(i + 16), + buffer.load_complex(i + 24), + ] + }; // For each column: load the data, apply our size-4 FFT, apply twiddle factors let mut tmp1 = self.bf8.bf4.perform_fft_direct(load(1)); @@ -3182,16 +3270,24 @@ impl SseF64Butterfly32 { }; // Size-8 FFTs down each of our transposed columns, storing them as soon as we're done with them - let out0 = self.bf8.perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0]]); + let out0 = self.bf8.perform_fft_direct([ + tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0], + ]); store(0, out0); - let out1 = self.bf8.perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1], tmp6[1], tmp7[1]]); + let out1 = self.bf8.perform_fft_direct([ + tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1], tmp6[1], tmp7[1], + ]); store(1, out1); - let out2 = self.bf8.perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2], tmp6[2], tmp7[2]]); + let out2 = self.bf8.perform_fft_direct([ + tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2], tmp6[2], tmp7[2], + ]); store(2, out2); - - let out3 = self.bf8.perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3], tmp6[3], tmp7[3]]); + + let out3 = self.bf8.perform_fft_direct([ + tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3], tmp6[3], tmp7[3], + ]); store(3, out3); } } @@ -3199,7 +3295,10 @@ impl SseF64Butterfly32 { #[cfg(test)] mod unit_tests { use super::*; - use crate::{algorithm::Dft, test_utils::{check_fft_algorithm, compare_vectors}}; + use crate::{ + algorithm::Dft, + test_utils::{check_fft_algorithm, compare_vectors}, + }; //the tests for all butterflies will be identical except for the identifiers used and size //so it's ideal for a macro diff --git a/src/sse/sse_utils.rs b/src/sse/sse_utils.rs index be4ac5f8..e1a93a21 100644 --- a/src/sse/sse_utils.rs +++ b/src/sse/sse_utils.rs @@ -63,7 +63,7 @@ impl Rotate90F32 { let temp = _mm_shuffle_ps(values, values, 0xB1); _mm_xor_ps(temp, self.sign_both) } - + #[inline(always)] pub unsafe fn rotate_both_45(&self, values: __m128) -> __m128 { let rotated = self.rotate_both(values); From 8ea6ad257864c7d2b9f00ee65f003c4d41aaef80 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Sat, 24 Feb 2024 21:25:30 -0800 Subject: [PATCH 07/13] Optimized neon butterflies to be more cache-friendly --- 
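Notes: this applies the same restructuring as the earlier SSE and WASM SIMD patches. The large butterflies drop the split-radix formulation in favor of mixed-radix passes that load each column as late as possible, finish its column FFT, twiddle multiplies, and transpose work before moving on, and store each output row as soon as its row FFT completes. The f32 paths also pre-pack the twiddles of two adjacent columns into one vector, halving the number of complex multiplies in the twiddle step. Below is a minimal sketch of where those packed indices come from for the size-16 case; it is an illustration only, assuming the num_complex crate and the usual forward convention twiddle(k, n) = e^(-2*pi*i*k/n), and its helper names are not taken from this patch.

// Sketch only: not code from this patch. Assumes num_complex and the forward-direction
// convention twiddle(k, n) = e^(-2*pi*i*k/n).
use num_complex::Complex;

fn twiddle(k: usize, n: usize) -> Complex<f32> {
    let angle = -2.0 * std::f32::consts::PI * k as f32 / n as f32;
    Complex::new(angle.cos(), angle.sin())
}

// A 4x4 mixed-radix step scales row r of column c by twiddle(r * c, 16). Each f32 SIMD
// vector holds two complex values, so the size-16 butterflies pack the twiddles of two
// adjacent columns per vector. This reproduces the six pairs stored in `twiddles_packed`:
// [tw0, tw1], [tw0, tw2], [tw0, tw3], [tw2, tw3], [tw4, tw6], [tw6, tw9] (with tw0 == 1).
fn packed_twiddles_16() -> [[Complex<f32>; 2]; 6] {
    let mut packed = [[Complex::new(0.0, 0.0); 2]; 6];
    let mut slot = 0;
    for (c0, c1) in [(0usize, 1), (2, 3)] {
        for r in 1..4 {
            packed[slot] = [twiddle(r * c0, 16), twiddle(r * c1, 16)];
            slot += 1;
        }
    }
    packed
}

The same indexing explains the size-24 and size-32 constructors in the diff below: row r of column c needs twiddle(r * c, len), which is where index sets like 1,2,3 / 2,4,6 / 3,6,9 come from.
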
src/neon/neon_butterflies.rs | 1918 ++++++++++++++-------------------- src/neon/neon_utils.rs | 42 + src/neon/neon_vector.rs | 13 + 3 files changed, 853 insertions(+), 1120 deletions(-) diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index 01effe0e..dabead51 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -15,11 +15,11 @@ use super::neon_utils::*; use super::neon_vector::{NeonArrayMut, NeonVector}; #[inline(always)] -unsafe fn pack32(a: Complex, b: Complex) -> float32x4_t { +unsafe fn pack_32(a: Complex, b: Complex) -> float32x4_t { vld1q_f32([a.re, a.im, b.re, b.im].as_ptr()) } #[inline(always)] -unsafe fn pack64(a: Complex) -> float64x2_t { +unsafe fn pack_64(a: Complex) -> float64x2_t { vld1q_f64([a.re, a.im].as_ptr()) } @@ -663,7 +663,7 @@ impl NeonF32Butterfly4 { let [value0ab, value1ab] = transpose_complex_2x2_f32(value01a, value01b); let [value2ab, value3ab] = transpose_complex_2x2_f32(value23a, value23b); - let out = self.perform_parallel_fft_direct(value0ab, value1ab, value2ab, value3ab); + let out = self.perform_parallel_fft_direct([value0ab, value1ab, value2ab, value3ab]); let [out0, out1] = transpose_complex_2x2_f32(out[0], out[1]); let [out2, out3] = transpose_complex_2x2_f32(out[2], out[3]); @@ -702,21 +702,15 @@ impl NeonF32Butterfly4 { } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_direct( - &self, - values0: float32x4_t, - values1: float32x4_t, - values2: float32x4_t, - values3: float32x4_t, - ) -> [float32x4_t; 4] { + pub(crate) unsafe fn perform_parallel_fft_direct(&self, values: [float32x4_t; 4]) -> [float32x4_t; 4] { //we're going to hardcode a step of mixed radix //aka we're going to do the six step algorithm // step 1: transpose // and // step 2: column FFTs - let temp0 = parallel_fft2_interleaved_f32(values0, values2); - let mut temp1 = parallel_fft2_interleaved_f32(values1, values3); + let temp0 = parallel_fft2_interleaved_f32(values[0], values[2]); + let mut temp1 = parallel_fft2_interleaved_f32(values[1], values[3]); // step 3: apply twiddle factors (only one in this case, and it's either 0 + i or 0 - i) temp1[1] = self.rotate.rotate_both(temp1[1]); @@ -773,7 +767,7 @@ impl NeonF64Butterfly4 { let value2 = buffer.load_complex(2); let value3 = buffer.load_complex(3); - let out = self.perform_fft_direct(value0, value1, value2, value3); + let out = self.perform_fft_direct([value0, value1, value2, value3]); buffer.store_complex(out[0], 0); buffer.store_complex(out[1], 1); @@ -782,21 +776,15 @@ impl NeonF64Butterfly4 { } #[inline(always)] - pub(crate) unsafe fn perform_fft_direct( - &self, - value0: float64x2_t, - value1: float64x2_t, - value2: float64x2_t, - value3: float64x2_t, - ) -> [float64x2_t; 4] { + pub(crate) unsafe fn perform_fft_direct(&self, values: [float64x2_t; 4]) -> [float64x2_t; 4] { //we're going to hardcode a step of mixed radix //aka we're going to do the six step algorithm // step 1: transpose // and // step 2: column FFTs - let temp0 = solo_fft2_f64(value0, value2); - let mut temp1 = solo_fft2_f64(value1, value3); + let temp0 = solo_fft2_f64(values[0], values[2]); + let mut temp1 = solo_fft2_f64(values[1], values[3]); // step 3: apply twiddle factors (only one in this case, and it's either 0 + i or 0 - i) temp1[1] = self.rotate.rotate(temp1[1]); @@ -1246,7 +1234,7 @@ impl NeonF64Butterfly6 { let value4 = buffer.load_complex(4); let value5 = buffer.load_complex(5); - let out = self.perform_fft_direct(value0, value1, value2, value3, value4, value5); + let out = 
self.perform_fft_direct([value0, value1, value2, value3, value4, value5]); buffer.store_complex(out[0], 0); buffer.store_complex(out[1], 1); @@ -1257,20 +1245,12 @@ impl NeonF64Butterfly6 { } #[inline(always)] - pub(crate) unsafe fn perform_fft_direct( - &self, - value0: float64x2_t, - value1: float64x2_t, - value2: float64x2_t, - value3: float64x2_t, - value4: float64x2_t, - value5: float64x2_t, - ) -> [float64x2_t; 6] { + pub(crate) unsafe fn perform_fft_direct(&self, values: [float64x2_t; 6]) -> [float64x2_t; 6] { // Algorithm: 3x2 good-thomas // Size-3 FFTs down the columns of our reordered array - let mid0 = self.bf3.perform_fft_direct(value0, value2, value4); - let mid1 = self.bf3.perform_fft_direct(value3, value5, value1); + let mid0 = self.bf3.perform_fft_direct(values[0], values[2], values[4]); + let mid1 = self.bf3.perform_fft_direct(values[3], values[5], values[1]); // We normally would put twiddle factors right here, but since this is good-thomas algorithm, we don't need twiddle factors @@ -1390,10 +1370,10 @@ impl NeonF32Butterfly8 { // step 2: column FFTs let val03 = self .bf4 - .perform_parallel_fft_direct(values[0], values[2], values[4], values[6]); + .perform_parallel_fft_direct([values[0], values[2], values[4], values[6]]); let mut val47 = self .bf4 - .perform_parallel_fft_direct(values[1], values[3], values[5], values[7]); + .perform_parallel_fft_direct([values[1], values[3], values[5], values[7]]); // step 3: apply twiddle factors let val5b = self.rotate90.rotate_both(val47[1]); @@ -1473,10 +1453,10 @@ impl NeonF64Butterfly8 { // step 2: column FFTs let val03 = self .bf4 - .perform_fft_direct(values[0], values[2], values[4], values[6]); + .perform_fft_direct([values[0], values[2], values[4], values[6]]); let mut val47 = self .bf4 - .perform_fft_direct(values[1], values[3], values[5], values[7]); + .perform_fft_direct([values[1], values[3], values[5], values[7]]); // step 3: apply twiddle factors let val5b = self.rotate90.rotate(val47[1]); @@ -2001,13 +1981,13 @@ impl NeonF32Butterfly12 { // Size-4 FFTs down the columns of our reordered array let mid0 = self .bf4 - .perform_parallel_fft_direct(values[0], values[3], values[6], values[9]); + .perform_parallel_fft_direct([values[0], values[3], values[6], values[9]]); let mid1 = self .bf4 - .perform_parallel_fft_direct(values[4], values[7], values[10], values[1]); + .perform_parallel_fft_direct([values[4], values[7], values[10], values[1]]); let mid2 = self .bf4 - .perform_parallel_fft_direct(values[8], values[11], values[2], values[5]); + .perform_parallel_fft_direct([values[8], values[11], values[2], values[5]]); // Since this is good-thomas algorithm, we don't need twiddle factors @@ -2081,13 +2061,13 @@ impl NeonF64Butterfly12 { // Size-4 FFTs down the columns of our reordered array let mid0 = self .bf4 - .perform_fft_direct(values[0], values[3], values[6], values[9]); + .perform_fft_direct([values[0], values[3], values[6], values[9]]); let mid1 = self .bf4 - .perform_fft_direct(values[4], values[7], values[10], values[1]); + .perform_fft_direct([values[4], values[7], values[10], values[1]]); let mid2 = self .bf4 - .perform_fft_direct(values[8], values[11], values[2], values[5]); + .perform_fft_direct([values[8], values[11], values[2], values[5]]); // Since this is good-thomas algorithm, we don't need twiddle factors @@ -2321,202 +2301,169 @@ impl NeonF64Butterfly15 { // pub struct NeonF32Butterfly16 { - direction: FftDirection, bf4: NeonF32Butterfly4, - bf8: NeonF32Butterfly8, - rotate90: Rotate90F32, - twiddle01: 
float32x4_t, - twiddle23: float32x4_t, - twiddle01conj: float32x4_t, - twiddle23conj: float32x4_t, + twiddles_packed: [float32x4_t; 6], twiddle1: float32x4_t, twiddle2: float32x4_t, twiddle3: float32x4_t, - twiddle1c: float32x4_t, - twiddle2c: float32x4_t, - twiddle3c: float32x4_t, + twiddle6: float32x4_t, + twiddle9: float32x4_t, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly16, 16, |this: &NeonF32Butterfly16<_>| this - .direction); +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly16, 16, |this: &NeonF32Butterfly16< + _, +>| this.bf4.direction); boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly16, 16, |this: &NeonF32Butterfly16<_>| this + .bf4 .direction); impl NeonF32Butterfly16 { - #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f32::(); - let bf8 = NeonF32Butterfly8::new(direction); - let bf4 = NeonF32Butterfly4::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F32::new(true) - } else { - Rotate90F32::new(false) - }; + let tw0: Complex = Complex { re: 1.0, im: 0.0 }; let tw1: Complex = twiddles::compute_twiddle(1, 16, direction); let tw2: Complex = twiddles::compute_twiddle(2, 16, direction); let tw3: Complex = twiddles::compute_twiddle(3, 16, direction); - let twiddle01 = unsafe { vld1q_f32([1.0, 0.0, tw1.re, tw1.im].as_ptr()) }; - let twiddle23 = unsafe { vld1q_f32([tw2.re, tw2.im, tw3.re, tw3.im].as_ptr()) }; - let twiddle01conj = unsafe { vld1q_f32([1.0, 0.0, tw1.re, -tw1.im].as_ptr()) }; - let twiddle23conj = unsafe { vld1q_f32([tw2.re, -tw2.im, tw3.re, -tw3.im].as_ptr()) }; - let twiddle1 = unsafe { vld1q_f32([tw1.re, tw1.im, tw1.re, tw1.im].as_ptr()) }; - let twiddle2 = unsafe { vld1q_f32([tw2.re, tw2.im, tw2.re, tw2.im].as_ptr()) }; - let twiddle3 = unsafe { vld1q_f32([tw3.re, tw3.im, tw3.re, tw3.im].as_ptr()) }; - let twiddle1c = unsafe { vld1q_f32([tw1.re, -tw1.im, tw1.re, -tw1.im].as_ptr()) }; - let twiddle2c = unsafe { vld1q_f32([tw2.re, -tw2.im, tw2.re, -tw2.im].as_ptr()) }; - let twiddle3c = unsafe { vld1q_f32([tw3.re, -tw3.im, tw3.re, -tw3.im].as_ptr()) }; - Self { - direction, - bf4, - bf8, - rotate90, - twiddle01, - twiddle23, - twiddle01conj, - twiddle23conj, - twiddle1, - twiddle2, - twiddle3, - twiddle1c, - twiddle2c, - twiddle3c, + let tw4: Complex = twiddles::compute_twiddle(4, 16, direction); + let tw6: Complex = twiddles::compute_twiddle(6, 16, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 16, direction); + + unsafe { + Self { + bf4: NeonF32Butterfly4::new(direction), + twiddles_packed: [ + pack_32(tw0, tw1), + pack_32(tw0, tw2), + pack_32(tw0, tw3), + pack_32(tw2, tw3), + pack_32(tw4, tw6), + pack_32(tw6, tw9), + ], + twiddle1: pack_32(tw1, tw1), + twiddle2: pack_32(tw2, tw2), + twiddle3: pack_32(tw3, tw3), + twiddle6: pack_32(tw6, tw6), + twiddle9: pack_32(tw9, tw9), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14 }); - - let out = self.perform_fft_direct(input_packed); - - write_complex_to_array_strided!(out, buffer, 2, {0,1,2,3,4,5,6,7}); - } - - #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_contiguous( - &self, - mut buffer: impl NeonArrayMut, - ) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}); - - let values = interleave_complex_f32!(input_packed, 8, {0, 1, 2, 3 ,4 ,5 ,6 ,7}); - - let out = self.perform_parallel_fft_direct(values); - - let out_sorted = 
separate_interleaved_complex_f32!(out, {0, 2, 4, 6, 8, 10, 12, 14}); - - write_complex_to_array_strided!(out_sorted, buffer, 2, {0,1,2,3,4,5,6,7,8,9, 10, 11,12,13,14, 15}); - } - - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [float32x4_t; 8]) -> [float32x4_t; 8] { - // we're going to hardcode a step of split radix - // step 1: copy and reorder the input into the scratch - let in0002 = extract_lo_lo_f32(input[0], input[1]); - let in0406 = extract_lo_lo_f32(input[2], input[3]); - let in0810 = extract_lo_lo_f32(input[4], input[5]); - let in1214 = extract_lo_lo_f32(input[6], input[7]); - - let in0105 = extract_hi_hi_f32(input[0], input[2]); - let in0913 = extract_hi_hi_f32(input[4], input[6]); - let in1503 = extract_hi_hi_f32(input[7], input[1]); - let in0711 = extract_hi_hi_f32(input[3], input[5]); - - let in_evens = [in0002, in0406, in0810, in1214]; - - // step 2: column FFTs - let evens = self.bf8.perform_fft_direct(in_evens); - let mut odds1 = self.bf4.perform_fft_direct(in0105, in0913); - let mut odds3 = self.bf4.perform_fft_direct(in1503, in0711); - - // step 3: apply twiddle factors - odds1[0] = NeonVector::mul_complex(odds1[0], self.twiddle01); - odds3[0] = NeonVector::mul_complex(odds3[0], self.twiddle01conj); - - odds1[1] = NeonVector::mul_complex(odds1[1], self.twiddle23); - odds3[1] = NeonVector::mul_complex(odds3[1], self.twiddle23conj); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 4), + buffer.load_complex(i + 8), + buffer.load_complex(i + 12), + ] + }; - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors, and transpose + let mut tmp0 = self.bf4.perform_parallel_fft_direct(load(0)); + tmp0[1] = NeonVector::mul_complex(tmp0[1], self.twiddles_packed[0]); + tmp0[2] = NeonVector::mul_complex(tmp0[2], self.twiddles_packed[1]); + tmp0[3] = NeonVector::mul_complex(tmp0[3], self.twiddles_packed[2]); + let [mid0, mid1] = transpose_complex_2x2_f32(tmp0[0], tmp0[1]); + let [mid4, mid5] = transpose_complex_2x2_f32(tmp0[2], tmp0[3]); + + let mut tmp1 = self.bf4.perform_parallel_fft_direct(load(2)); + tmp1[1] = NeonVector::mul_complex(tmp1[1], self.twiddles_packed[3]); + tmp1[2] = NeonVector::mul_complex(tmp1[2], self.twiddles_packed[4]); + tmp1[3] = NeonVector::mul_complex(tmp1[3], self.twiddles_packed[5]); + let [mid2, mid3] = transpose_complex_2x2_f32(tmp1[0], tmp1[1]); + let [mid6, mid7] = 
transpose_complex_2x2_f32(tmp1[2], tmp1[3]); + + //////////////////////////////////////////////////////////// + let mut store = |i: usize, vectors: [float32x4_t; 4]| { + buffer.store_complex(vectors[0], i + 0); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + }; + // Size-4 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self + .bf4 + .perform_parallel_fft_direct([mid0, mid1, mid2, mid3]); + store(0, out0); - //step 5: copy/add/subtract data back to buffer - [ - vaddq_f32(evens[0], temp0[0]), - vaddq_f32(evens[1], temp1[0]), - vaddq_f32(evens[2], temp0[1]), - vaddq_f32(evens[3], temp1[1]), - vsubq_f32(evens[0], temp0[0]), - vsubq_f32(evens[1], temp1[0]), - vsubq_f32(evens[2], temp0[1]), - vsubq_f32(evens[3], temp1[1]), - ] - } + let out1 = self + .bf4 + .perform_parallel_fft_direct([mid4, mid5, mid6, mid7]); + store(2, out1); + } + + pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i: usize| { + let [a0, a1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 16)); + let [b0, b1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 4), buffer.load_complex(i + 20)); + let [c0, c1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 8), buffer.load_complex(i + 24)); + let [d0, d1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 12), buffer.load_complex(i + 28)); + [[a0, b0, c0, d0], [a1, b1, c1, d1]] + }; - #[inline(always)] - unsafe fn perform_parallel_fft_direct(&self, input: [float32x4_t; 16]) -> [float32x4_t; 16] { - // we're going to hardcode a step of split radix - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf8.perform_parallel_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - ]); - let mut odds1 = self + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors + let [in2, in3] = load(2); + let mut tmp2 = self.bf4.perform_parallel_fft_direct(in2); + let mut tmp3 = self.bf4.perform_parallel_fft_direct(in3); + tmp2[1] = NeonVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[2] = self.bf4.rotate.rotate_both(tmp2[2]); + tmp2[3] = NeonVector::mul_complex(tmp2[3], self.twiddle6); + tmp3[1] = NeonVector::mul_complex(tmp3[1], self.twiddle3); + tmp3[2] = NeonVector::mul_complex(tmp3[2], self.twiddle6); + tmp3[3] = NeonVector::mul_complex(tmp3[3], self.twiddle9); + + // Do these last, because fewer twiddles means fewer temporaries forcing the above data to spill + let [in0, in1] = load(0); + let tmp0 = self.bf4.perform_parallel_fft_direct(in0); + let mut tmp1 = 
self.bf4.perform_parallel_fft_direct(in1); + tmp1[1] = NeonVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = NeonVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[3] = NeonVector::mul_complex(tmp1[3], self.twiddle3); + + //////////////////////////////////////////////////////////// + let mut store = |i, values_a: [float32x4_t; 4], values_b: [float32x4_t; 4]| { + for n in 0..4 { + let [a, b] = transpose_complex_2x2_f32(values_a[n], values_b[n]); + buffer.store_complex(a, i + n * 4); + buffer.store_complex(b, i + n * 4 + 16); + } + }; + // Size-4 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self .bf4 - .perform_parallel_fft_direct(input[1], input[5], input[9], input[13]); - let mut odds3 = self + .perform_parallel_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); + let out1 = self .bf4 - .perform_parallel_fft_direct(input[15], input[3], input[7], input[11]); - - // step 3: apply twiddle factors - odds1[1] = NeonVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = NeonVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = NeonVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = NeonVector::mul_complex(odds3[2], self.twiddle2c); + .perform_parallel_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); + store(0, out0, out1); - odds1[3] = NeonVector::mul_complex(odds1[3], self.twiddle3); - odds3[3] = NeonVector::mul_complex(odds3[3], self.twiddle3c); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); - - //step 5: copy/add/subtract data back to buffer - [ - vaddq_f32(evens[0], temp0[0]), - vaddq_f32(evens[1], temp1[0]), - vaddq_f32(evens[2], temp2[0]), - vaddq_f32(evens[3], temp3[0]), - vaddq_f32(evens[4], temp0[1]), - vaddq_f32(evens[5], temp1[1]), - vaddq_f32(evens[6], temp2[1]), - vaddq_f32(evens[7], temp3[1]), - vsubq_f32(evens[0], temp0[0]), - vsubq_f32(evens[1], temp1[0]), - vsubq_f32(evens[2], temp2[0]), - vsubq_f32(evens[3], temp3[0]), - vsubq_f32(evens[4], temp0[1]), - vsubq_f32(evens[5], temp1[1]), - vsubq_f32(evens[6], temp2[1]), - vsubq_f32(evens[7], temp3[1]), - ] + let out2 = self + .bf4 + .perform_parallel_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2]]); + let out3 = self + .bf4 + .perform_parallel_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3]]); + store(2, out2, out3); } } - // _ __ __ _ _ _ _ _ // / |/ /_ / /_ | || | | |__ (_) |_ // | | '_ \ _____ | '_ \| || |_| '_ \| | __| @@ -2525,143 +2472,101 @@ impl NeonF32Butterfly16 { // pub struct NeonF64Butterfly16 { - direction: FftDirection, bf4: NeonF64Butterfly4, - bf8: NeonF64Butterfly8, - rotate90: Rotate90F64, twiddle1: float64x2_t, - twiddle2: float64x2_t, twiddle3: float64x2_t, - twiddle1c: float64x2_t, - twiddle2c: float64x2_t, - twiddle3c: float64x2_t, + twiddle9: float64x2_t, } boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly16, 16, |this: &NeonF64Butterfly16<_>| this + .bf4 .direction); boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly16, 16, |this: &NeonF64Butterfly16<_>| this + .bf4 .direction); impl NeonF64Butterfly16 { 
#[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f64::(); - let bf8 = NeonF64Butterfly8::new(direction); - let bf4 = NeonF64Butterfly4::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F64::new(true) - } else { - Rotate90F64::new(false) - }; - let twiddle1 = unsafe { - vld1q_f64(&twiddles::compute_twiddle::(1, 16, direction) as *const _ as *const f64) - }; - let twiddle2 = unsafe { - vld1q_f64(&twiddles::compute_twiddle::(2, 16, direction) as *const _ as *const f64) - }; - let twiddle3 = unsafe { - vld1q_f64(&twiddles::compute_twiddle::(3, 16, direction) as *const _ as *const f64) - }; - let twiddle1c = unsafe { - vld1q_f64( - &twiddles::compute_twiddle::(1, 16, direction).conj() as *const _ - as *const f64, - ) - }; - let twiddle2c = unsafe { - vld1q_f64( - &twiddles::compute_twiddle::(2, 16, direction).conj() as *const _ - as *const f64, - ) - }; - let twiddle3c = unsafe { - vld1q_f64( - &twiddles::compute_twiddle::(3, 16, direction).conj() as *const _ - as *const f64, - ) - }; + let tw1: Complex = twiddles::compute_twiddle(1, 16, direction); + let tw3: Complex = twiddles::compute_twiddle(3, 16, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 16, direction); - Self { - direction, - bf4, - bf8, - rotate90, - twiddle1, - twiddle2, - twiddle3, - twiddle1c, - twiddle2c, - twiddle3c, + unsafe { + Self { + bf4: NeonF64Butterfly4::new(direction), + twiddle1: pack_64(tw1), + twiddle3: pack_64(tw3), + twiddle9: pack_64(tw9), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { - let values = - read_complex_to_array!(buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - - let out = self.perform_fft_direct(values); - - write_complex_to_array!(out, buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - } + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 4), + buffer.load_complex(i + 8), + buffer.load_complex(i + 12), + ] + }; - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [float64x2_t; 16]) -> [float64x2_t; 16] { - // we're going to hardcode a step of split radix + // For each column: load the data, apply our size-4 FFT, apply twiddle factors + let mut tmp1 = self.bf4.perform_fft_direct(load(1)); + tmp1[1] = NeonVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = self.bf4.rotate.rotate_45(tmp1[2]); + tmp1[3] = NeonVector::mul_complex(tmp1[3], self.twiddle3); + + let mut tmp3 = self.bf4.perform_fft_direct(load(3)); + tmp3[1] = NeonVector::mul_complex(tmp3[1], self.twiddle3); + tmp3[2] = self.bf4.rotate.rotate_135(tmp3[2]); + tmp3[3] = NeonVector::mul_complex(tmp3[3], self.twiddle9); + + let mut tmp2 = self.bf4.perform_fft_direct(load(2)); + tmp2[1] = 
self.bf4.rotate.rotate_45(tmp2[1]); + tmp2[2] = self.bf4.rotate.rotate(tmp2[2]); + tmp2[3] = self.bf4.rotate.rotate_135(tmp2[3]); + + // Do the first column last, because no twiddles means fewer temporaries forcing the above data to spill + let tmp0 = self.bf4.perform_fft_direct(load(0)); + + //////////////////////////////////////////////////////////// + let mut store = |i: usize, vectors: [float64x2_t; 4]| { + buffer.store_complex(vectors[0], i + 0); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + }; - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf8.perform_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - ]); - let mut odds1 = self - .bf4 - .perform_fft_direct(input[1], input[5], input[9], input[13]); - let mut odds3 = self + // Size-4 FFTs down each of our transposed columns, storing them as soon as we're done with them + let out0 = self .bf4 - .perform_fft_direct(input[15], input[3], input[7], input[11]); - - // step 3: apply twiddle factors - odds1[1] = NeonVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = NeonVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = NeonVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = NeonVector::mul_complex(odds3[2], self.twiddle2c); + .perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0]]); + store(0, out0); - odds1[3] = NeonVector::mul_complex(odds1[3], self.twiddle3); - odds3[3] = NeonVector::mul_complex(odds3[3], self.twiddle3c); - - // step 4: cross FFTs - let mut temp0 = solo_fft2_f64(odds1[0], odds3[0]); - let mut temp1 = solo_fft2_f64(odds1[1], odds3[1]); - let mut temp2 = solo_fft2_f64(odds1[2], odds3[2]); - let mut temp3 = solo_fft2_f64(odds1[3], odds3[3]); + let out1 = self + .bf4 + .perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1]]); + store(1, out1); - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate(temp0[1]); - temp1[1] = self.rotate90.rotate(temp1[1]); - temp2[1] = self.rotate90.rotate(temp2[1]); - temp3[1] = self.rotate90.rotate(temp3[1]); + let out2 = self + .bf4 + .perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2]]); + store(2, out2); - //step 5: copy/add/subtract data back to buffer - [ - vaddq_f64(evens[0], temp0[0]), - vaddq_f64(evens[1], temp1[0]), - vaddq_f64(evens[2], temp2[0]), - vaddq_f64(evens[3], temp3[0]), - vaddq_f64(evens[4], temp0[1]), - vaddq_f64(evens[5], temp1[1]), - vaddq_f64(evens[6], temp2[1]), - vaddq_f64(evens[7], temp3[1]), - vsubq_f64(evens[0], temp0[0]), - vsubq_f64(evens[1], temp1[0]), - vsubq_f64(evens[2], temp2[0]), - vsubq_f64(evens[3], temp3[0]), - vsubq_f64(evens[4], temp0[1]), - vsubq_f64(evens[5], temp1[1]), - vsubq_f64(evens[6], temp2[1]), - vsubq_f64(evens[7], temp3[1]), - ] + let out3 = self + .bf4 + .perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3]]); + store(3, out3); } } @@ -2673,249 +2578,200 @@ impl NeonF64Butterfly16 { // pub struct NeonF32Butterfly24 { - direction: FftDirection, + bf4: NeonF32Butterfly4, bf6: NeonF32Butterfly6, - bf12: NeonF32Butterfly12, - rotate90: Rotate90F32, - twiddle01: float32x4_t, - twiddle23: float32x4_t, - twiddle45: float32x4_t, - twiddle01conj: float32x4_t, - twiddle23conj: float32x4_t, - twiddle45conj: float32x4_t, + twiddles_packed: [float32x4_t; 9], twiddle1: float32x4_t, twiddle2: float32x4_t, twiddle4: float32x4_t, twiddle5: float32x4_t, - twiddle1c: float32x4_t, - 
twiddle2c: float32x4_t, - twiddle4c: float32x4_t, - twiddle5c: float32x4_t, + twiddle8: float32x4_t, + twiddle10: float32x4_t, } boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly24, 24, |this: &NeonF32Butterfly24<_>| { - this.direction + this.bf4.direction }); boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly24, 24, |this: &NeonF32Butterfly24<_>| this + .bf4 .direction); impl NeonF32Butterfly24 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f32::(); - let tw0 = Complex { re: 1.0, im: 0.0 }; - let tw1 = twiddles::compute_twiddle(1, 24, direction); - let tw2 = twiddles::compute_twiddle(2, 24, direction); - let tw3 = twiddles::compute_twiddle(3, 24, direction); - let tw4 = twiddles::compute_twiddle(4, 24, direction); - let tw5 = twiddles::compute_twiddle(5, 24, direction); + let tw0: Complex = Complex { re: 1.0, im: 0.0 }; + let tw1: Complex = twiddles::compute_twiddle(1, 24, direction); + let tw2: Complex = twiddles::compute_twiddle(2, 24, direction); + let tw3: Complex = twiddles::compute_twiddle(3, 24, direction); + let tw4: Complex = twiddles::compute_twiddle(4, 24, direction); + let tw5: Complex = twiddles::compute_twiddle(5, 24, direction); + let tw6: Complex = twiddles::compute_twiddle(6, 24, direction); + let tw8: Complex = twiddles::compute_twiddle(8, 24, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 24, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 24, direction); + let tw12: Complex = twiddles::compute_twiddle(12, 24, direction); + let tw15: Complex = twiddles::compute_twiddle(15, 24, direction); unsafe { Self { - direction, + bf4: NeonF32Butterfly4::new(direction), bf6: NeonF32Butterfly6::new(direction), - bf12: NeonF32Butterfly12::new(direction), - rotate90: Rotate90F32::new(direction == FftDirection::Inverse), - twiddle01: pack32(tw0, tw1), - twiddle23: pack32(tw2, tw3), - twiddle45: pack32(tw4, tw5), - twiddle01conj: pack32(tw0.conj(), tw1.conj()), - twiddle23conj: pack32(tw2.conj(), tw3.conj()), - twiddle45conj: pack32(tw4.conj(), tw5.conj()), - twiddle1: pack32(tw1, tw1), - twiddle2: pack32(tw2, tw2), - twiddle4: pack32(tw4, tw4), - twiddle5: pack32(tw5, tw5), - twiddle1c: pack32(tw1.conj(), tw1.conj()), - twiddle2c: pack32(tw2.conj(), tw2.conj()), - twiddle4c: pack32(tw4.conj(), tw4.conj()), - twiddle5c: pack32(tw5.conj(), tw5.conj()), + twiddles_packed: [ + pack_32(tw0, tw1), + pack_32(tw0, tw2), + pack_32(tw0, tw3), + pack_32(tw2, tw3), + pack_32(tw4, tw6), + pack_32(tw6, tw9), + pack_32(tw4, tw5), + pack_32(tw8, tw10), + pack_32(tw12, tw15), + ], + twiddle1: pack_32(tw1, tw1), + twiddle2: pack_32(tw2, tw2), + twiddle4: pack_32(tw4, tw4), + twiddle5: pack_32(tw5, tw5), + twiddle8: pack_32(tw8, tw8), + twiddle10: pack_32(tw10, tw10), } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { - let input_packed = - read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22}); - - let out = self.perform_fft_direct(input_packed); - - write_complex_to_array_strided!(out, buffer, 2, {0,1,2,3,4,5,6,7,8,9,10,11}); - } - - #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_contiguous( - &self, - mut buffer: impl NeonArrayMut, - ) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46}); - - let values = - interleave_complex_f32!(input_packed, 12, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); - - let out = self.perform_parallel_fft_direct(values); - - let out_sorted = 
- separate_interleaved_complex_f32!(out, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22}); - - write_complex_to_array_strided!(out_sorted, buffer, 2, {0,1,2,3,4,5,6,7,8,9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 }); - } - - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [float32x4_t; 12]) -> [float32x4_t; 12] { - // we're going to hardcode a step of split radix - - // step 1: copy and reorder the input into the scratch - let in0002 = extract_lo_lo_f32(input[0], input[1]); - let in0406 = extract_lo_lo_f32(input[2], input[3]); - let in0810 = extract_lo_lo_f32(input[4], input[5]); - let in1214 = extract_lo_lo_f32(input[6], input[7]); - let in1618 = extract_lo_lo_f32(input[8], input[9]); - let in2022 = extract_lo_lo_f32(input[10], input[11]); - - let in0105 = extract_hi_hi_f32(input[0], input[2]); - let in0913 = extract_hi_hi_f32(input[4], input[6]); - let in1721 = extract_hi_hi_f32(input[8], input[10]); - - let in2303 = extract_hi_hi_f32(input[11], input[1]); - let in0711 = extract_hi_hi_f32(input[3], input[5]); - let in1519 = extract_hi_hi_f32(input[7], input[9]); - - let in_evens = [in0002, in0406, in0810, in1214, in1618, in2022]; - - // step 2: column FFTs - let evens = self.bf12.perform_fft_direct(in_evens); - let mut odds1 = self.bf6.perform_fft_direct(in0105, in0913, in1721); - let mut odds3 = self.bf6.perform_fft_direct(in2303, in0711, in1519); - - // step 3: apply twiddle factors - odds1[0] = NeonVector::mul_complex(odds1[0], self.twiddle01); - odds3[0] = NeonVector::mul_complex(odds3[0], self.twiddle01conj); - - odds1[1] = NeonVector::mul_complex(odds1[1], self.twiddle23); - odds3[1] = NeonVector::mul_complex(odds3[1], self.twiddle23conj); - - odds1[2] = NeonVector::mul_complex(odds1[2], self.twiddle45); - odds3[2] = NeonVector::mul_complex(odds3[2], self.twiddle45conj); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - - //step 5: copy/add/subtract data back to buffer - [ - vaddq_f32(evens[0], temp0[0]), - vaddq_f32(evens[1], temp1[0]), - vaddq_f32(evens[2], temp2[0]), - vaddq_f32(evens[3], temp0[1]), - vaddq_f32(evens[4], temp1[1]), - vaddq_f32(evens[5], temp2[1]), - vsubq_f32(evens[0], temp0[0]), - vsubq_f32(evens[1], temp1[0]), - vsubq_f32(evens[2], temp2[0]), - vsubq_f32(evens[3], temp0[1]), - vsubq_f32(evens[4], temp1[1]), - vsubq_f32(evens[5], temp2[1]), - ] - } - - #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_direct( - &self, - input: [float32x4_t; 24], - ) -> [float32x4_t; 24] { - // we're going to hardcode a step of split radix + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the 
transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 6), + buffer.load_complex(i + 12), + buffer.load_complex(i + 18), + ] + }; - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf12.perform_parallel_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], - ]); - let mut odds1 = self.bf6.perform_parallel_fft_direct( - input[1], input[5], input[9], input[13], input[17], input[21], - ); - let mut odds3 = self.bf6.perform_parallel_fft_direct( - input[23], input[3], input[7], input[11], input[15], input[19], - ); + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors, transpose + let mut tmp1 = self.bf4.perform_parallel_fft_direct(load(2)); + tmp1[1] = NeonVector::mul_complex(tmp1[1], self.twiddles_packed[3]); + tmp1[2] = NeonVector::mul_complex(tmp1[2], self.twiddles_packed[4]); + tmp1[3] = NeonVector::mul_complex(tmp1[3], self.twiddles_packed[5]); + let [mid2, mid3] = transpose_complex_2x2_f32(tmp1[0], tmp1[1]); + let [mid8, mid9] = transpose_complex_2x2_f32(tmp1[2], tmp1[3]); + + let mut tmp2 = self.bf4.perform_parallel_fft_direct(load(4)); + tmp2[1] = NeonVector::mul_complex(tmp2[1], self.twiddles_packed[6]); + tmp2[2] = NeonVector::mul_complex(tmp2[2], self.twiddles_packed[7]); + tmp2[3] = NeonVector::mul_complex(tmp2[3], self.twiddles_packed[8]); + let [mid4, mid5] = transpose_complex_2x2_f32(tmp2[0], tmp2[1]); + let [mid10, mid11] = transpose_complex_2x2_f32(tmp2[2], tmp2[3]); + + let mut tmp0 = self.bf4.perform_parallel_fft_direct(load(0)); + tmp0[1] = NeonVector::mul_complex(tmp0[1], self.twiddles_packed[0]); + tmp0[2] = NeonVector::mul_complex(tmp0[2], self.twiddles_packed[1]); + tmp0[3] = NeonVector::mul_complex(tmp0[3], self.twiddles_packed[2]); + let [mid0, mid1] = transpose_complex_2x2_f32(tmp0[0], tmp0[1]); + let [mid6, mid7] = transpose_complex_2x2_f32(tmp0[2], tmp0[3]); + + //////////////////////////////////////////////////////////// + let mut store = |i, vectors: [float32x4_t; 6]| { + buffer.store_complex(vectors[0], i); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + buffer.store_complex(vectors[4], i + 16); + buffer.store_complex(vectors[5], i + 20); + }; - // twiddle factor helpers - let rotate45 = |vec| { - let rotated = self.rotate90.rotate_both(vec); - let sum = vaddq_f32(vec, rotated); - vmulq_f32(sum, vld1q_dup_f32(&0.5f32.sqrt())) + // Size-6 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self + .bf6 + .perform_parallel_fft_direct(mid0, mid1, mid2, mid3, mid4, mid5); + store(0, out0); + + let out1 = self + .bf6 + .perform_parallel_fft_direct(mid6, mid7, mid8, mid9, mid10, mid11); + store(2, out1); + } + + #[inline(always)] + pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as 
late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i: usize| { + let [a0, a1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 24)); + let [b0, b1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 6), buffer.load_complex(i + 30)); + let [c0, c1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 12), buffer.load_complex(i + 36)); + let [d0, d1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 18), buffer.load_complex(i + 42)); + [[a0, b0, c0, d0], [a1, b1, c1, d1]] }; - let rotate315 = |vec| { - let rotated = self.rotate90.rotate_both(vec); - let sum = vsubq_f32(vec, rotated); - vmulq_f32(sum, vld1q_dup_f32(&0.5f32.sqrt())) + + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors + let [in0, in1] = load(0); + let tmp0 = self.bf4.perform_parallel_fft_direct(in0); + let mut tmp1 = self.bf4.perform_parallel_fft_direct(in1); + tmp1[1] = NeonVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = NeonVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[3] = self.bf4.rotate.rotate_both_45(tmp1[3]); + + let [in2, in3] = load(2); + let mut tmp2 = self.bf4.perform_parallel_fft_direct(in2); + let mut tmp3 = self.bf4.perform_parallel_fft_direct(in3); + tmp2[1] = NeonVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[2] = NeonVector::mul_complex(tmp2[2], self.twiddle4); + tmp2[3] = self.bf4.rotate.rotate_both(tmp2[3]); + tmp3[1] = self.bf4.rotate.rotate_both_45(tmp3[1]); + tmp3[2] = self.bf4.rotate.rotate_both(tmp3[2]); + tmp3[3] = self.bf4.rotate.rotate_both_135(tmp3[3]); + + let [in4, in5] = load(4); + let mut tmp4 = self.bf4.perform_parallel_fft_direct(in4); + let mut tmp5 = self.bf4.perform_parallel_fft_direct(in5); + tmp4[1] = NeonVector::mul_complex(tmp4[1], self.twiddle4); + tmp4[2] = NeonVector::mul_complex(tmp4[2], self.twiddle8); + tmp4[3] = NeonVector::neg(tmp4[3]); + tmp5[1] = NeonVector::mul_complex(tmp5[1], self.twiddle5); + tmp5[2] = NeonVector::mul_complex(tmp5[2], self.twiddle10); + tmp5[3] = self.bf4.rotate.rotate_both_225(tmp5[3]); + + //////////////////////////////////////////////////////////// + let mut store = |i, vectors_a: [float32x4_t; 6], vectors_b: [float32x4_t; 6]| { + for n in 0..6 { + let [a, b] = transpose_complex_2x2_f32(vectors_a[n], vectors_b[n]); + buffer.store_complex(a, i + n * 4); + buffer.store_complex(b, i + n * 4 + 24); + } }; - // step 3: apply twiddle factors - odds1[1] = NeonVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = NeonVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = NeonVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = NeonVector::mul_complex(odds3[2], self.twiddle2c); - - odds1[3] = rotate45(odds1[3]); - odds3[3] = rotate315(odds3[3]); - - odds1[4] = NeonVector::mul_complex(odds1[4], self.twiddle4); - odds3[4] = NeonVector::mul_complex(odds3[4], self.twiddle4c); - - odds1[5] = NeonVector::mul_complex(odds1[5], self.twiddle5); - odds3[5] = NeonVector::mul_complex(odds3[5], self.twiddle5c); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = 
parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); - let mut temp4 = parallel_fft2_interleaved_f32(odds1[4], odds3[4]); - let mut temp5 = parallel_fft2_interleaved_f32(odds1[5], odds3[5]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); - temp4[1] = self.rotate90.rotate_both(temp4[1]); - temp5[1] = self.rotate90.rotate_both(temp5[1]); - - //step 5: copy/add/subtract data back to buffer - [ - vaddq_f32(evens[0], temp0[0]), - vaddq_f32(evens[1], temp1[0]), - vaddq_f32(evens[2], temp2[0]), - vaddq_f32(evens[3], temp3[0]), - vaddq_f32(evens[4], temp4[0]), - vaddq_f32(evens[5], temp5[0]), - vaddq_f32(evens[6], temp0[1]), - vaddq_f32(evens[7], temp1[1]), - vaddq_f32(evens[8], temp2[1]), - vaddq_f32(evens[9], temp3[1]), - vaddq_f32(evens[10], temp4[1]), - vaddq_f32(evens[11], temp5[1]), - vsubq_f32(evens[0], temp0[0]), - vsubq_f32(evens[1], temp1[0]), - vsubq_f32(evens[2], temp2[0]), - vsubq_f32(evens[3], temp3[0]), - vsubq_f32(evens[4], temp4[0]), - vsubq_f32(evens[5], temp5[0]), - vsubq_f32(evens[6], temp0[1]), - vsubq_f32(evens[7], temp1[1]), - vsubq_f32(evens[8], temp2[1]), - vsubq_f32(evens[9], temp3[1]), - vsubq_f32(evens[10], temp4[1]), - vsubq_f32(evens[11], temp5[1]), - ] + // Size-6 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self + .bf6 + .perform_parallel_fft_direct(tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]); + let out1 = self + .bf6 + .perform_parallel_fft_direct(tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]); + store(0, out0, out1); + + let out2 = self + .bf6 + .perform_parallel_fft_direct(tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2]); + let out3 = self + .bf6 + .perform_parallel_fft_direct(tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3]); + store(2, out2, out3); } } @@ -2927,149 +2783,124 @@ impl NeonF32Butterfly24 { // pub struct NeonF64Butterfly24 { - direction: FftDirection, + bf4: NeonF64Butterfly4, bf6: NeonF64Butterfly6, - bf12: NeonF64Butterfly12, - rotate90: Rotate90F64, twiddle1: float64x2_t, twiddle2: float64x2_t, twiddle4: float64x2_t, twiddle5: float64x2_t, - twiddle1c: float64x2_t, - twiddle2c: float64x2_t, - twiddle4c: float64x2_t, - twiddle5c: float64x2_t, + twiddle8: float64x2_t, + twiddle10: float64x2_t, } boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly24, 24, |this: &NeonF64Butterfly24<_>| { - this.direction + this.bf4.direction }); boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly24, 24, |this: &NeonF64Butterfly24<_>| this + .bf4 .direction); impl NeonF64Butterfly24 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f64::(); - let twiddle1 = twiddles::compute_twiddle(1, 24, direction); - let twiddle2 = twiddles::compute_twiddle(2, 24, direction); - let twiddle4 = twiddles::compute_twiddle(4, 24, direction); - let twiddle5 = twiddles::compute_twiddle(5, 24, direction); + let tw1: Complex = twiddles::compute_twiddle(1, 24, direction); + let tw2: Complex = twiddles::compute_twiddle(2, 24, direction); + let tw4: Complex = twiddles::compute_twiddle(4, 24, direction); + let tw5: Complex = twiddles::compute_twiddle(5, 24, direction); + let tw8: Complex = twiddles::compute_twiddle(8, 24, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 24, direction); 
+ unsafe { Self { - direction, + bf4: NeonF64Butterfly4::new(direction), bf6: NeonF64Butterfly6::new(direction), - bf12: NeonF64Butterfly12::new(direction), - rotate90: Rotate90F64::new(direction == FftDirection::Inverse), - twiddle1: pack64(twiddle1), - twiddle2: pack64(twiddle2), - twiddle4: pack64(twiddle4), - twiddle5: pack64(twiddle5), - twiddle1c: pack64(twiddle1.conj()), - twiddle2c: pack64(twiddle2.conj()), - twiddle4c: pack64(twiddle4.conj()), - twiddle5c: pack64(twiddle5.conj()), + twiddle1: pack_64(tw1), + twiddle2: pack_64(tw2), + twiddle4: pack_64(tw4), + twiddle5: pack_64(tw5), + twiddle8: pack_64(tw8), + twiddle10: pack_64(tw10), } } } #[inline(always)] - pub(crate) unsafe fn perform_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { - let values = read_complex_to_array!(buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + unsafe fn perform_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 6), + buffer.load_complex(i + 12), + buffer.load_complex(i + 18), + ] + }; - let out = self.perform_fft_direct(values); + // For each column: load the data, apply our size-4 FFT, apply twiddle factors + let mut tmp1 = self.bf4.perform_fft_direct(load(1)); + tmp1[1] = NeonVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = NeonVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[3] = self.bf4.rotate.rotate_45(tmp1[3]); + + let mut tmp2 = self.bf4.perform_fft_direct(load(2)); + tmp2[1] = NeonVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[2] = NeonVector::mul_complex(tmp2[2], self.twiddle4); + tmp2[3] = self.bf4.rotate.rotate(tmp2[3]); + + let mut tmp4 = self.bf4.perform_fft_direct(load(4)); + tmp4[1] = NeonVector::mul_complex(tmp4[1], self.twiddle4); + tmp4[2] = NeonVector::mul_complex(tmp4[2], self.twiddle8); + tmp4[3] = NeonVector::neg(tmp4[3]); + + let mut tmp5 = self.bf4.perform_fft_direct(load(5)); + tmp5[1] = NeonVector::mul_complex(tmp5[1], self.twiddle5); + tmp5[2] = NeonVector::mul_complex(tmp5[2], self.twiddle10); + tmp5[3] = self.bf4.rotate.rotate_225(tmp5[3]); + + let mut tmp3 = self.bf4.perform_fft_direct(load(3)); + tmp3[1] = self.bf4.rotate.rotate_45(tmp3[1]); + tmp3[2] = self.bf4.rotate.rotate(tmp3[2]); + tmp3[3] = self.bf4.rotate.rotate_135(tmp3[3]); + + // Do the first column last, because no twiddles means fewer temporaries forcing the above data to spill + let tmp0 = self.bf4.perform_fft_direct(load(0)); + + //////////////////////////////////////////////////////////// + let mut store = |i, vectors: [float64x2_t; 6]| { + buffer.store_complex(vectors[0], i); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); 
+ buffer.store_complex(vectors[4], i + 16); + buffer.store_complex(vectors[5], i + 20); + }; - write_complex_to_array!(out, buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); - } + // Size-6 FFTs down each of our transposed columns, storing them as soon as we're done with them + let out0 = self + .bf6 + .perform_fft_direct([tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0]]); + store(0, out0); - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [float64x2_t; 24]) -> [float64x2_t; 24] { - // we're going to hardcode a step of split radix + let out1 = self + .bf6 + .perform_fft_direct([tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1]]); + store(1, out1); - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf12.perform_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], - ]); - let mut odds1 = self.bf6.perform_fft_direct( - input[1], input[5], input[9], input[13], input[17], input[21], - ); - let mut odds3 = self.bf6.perform_fft_direct( - input[23], input[3], input[7], input[11], input[15], input[19], - ); - - // twiddle factor helpers - let rotate45 = |vec| { - let rotated = self.rotate90.rotate(vec); - let sum = vaddq_f64(vec, rotated); - vmulq_f64(sum, vld1q_dup_f64(&0.5f64.sqrt())) - }; - let rotate315 = |vec| { - let rotated = self.rotate90.rotate(vec); - let sum = vsubq_f64(vec, rotated); - vmulq_f64(sum, vld1q_dup_f64(&0.5f64.sqrt())) - }; + let out2 = self + .bf6 + .perform_fft_direct([tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2]]); + store(2, out2); - // step 3: apply twiddle factors - odds1[1] = NeonVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = NeonVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = NeonVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = NeonVector::mul_complex(odds3[2], self.twiddle2c); - - odds1[3] = rotate45(odds1[3]); - odds3[3] = rotate315(odds3[3]); - - odds1[4] = NeonVector::mul_complex(odds1[4], self.twiddle4); - odds3[4] = NeonVector::mul_complex(odds3[4], self.twiddle4c); - - odds1[5] = NeonVector::mul_complex(odds1[5], self.twiddle5); - odds3[5] = NeonVector::mul_complex(odds3[5], self.twiddle5c); - - // step 4: cross FFTs - let mut temp0 = solo_fft2_f64(odds1[0], odds3[0]); - let mut temp1 = solo_fft2_f64(odds1[1], odds3[1]); - let mut temp2 = solo_fft2_f64(odds1[2], odds3[2]); - let mut temp3 = solo_fft2_f64(odds1[3], odds3[3]); - let mut temp4 = solo_fft2_f64(odds1[4], odds3[4]); - let mut temp5 = solo_fft2_f64(odds1[5], odds3[5]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate(temp0[1]); - temp1[1] = self.rotate90.rotate(temp1[1]); - temp2[1] = self.rotate90.rotate(temp2[1]); - temp3[1] = self.rotate90.rotate(temp3[1]); - temp4[1] = self.rotate90.rotate(temp4[1]); - temp5[1] = self.rotate90.rotate(temp5[1]); - - //step 5: copy/add/subtract data back to buffer - [ - vaddq_f64(evens[0], temp0[0]), - vaddq_f64(evens[1], temp1[0]), - vaddq_f64(evens[2], temp2[0]), - vaddq_f64(evens[3], temp3[0]), - vaddq_f64(evens[4], temp4[0]), - vaddq_f64(evens[5], temp5[0]), - vaddq_f64(evens[6], temp0[1]), - vaddq_f64(evens[7], temp1[1]), - vaddq_f64(evens[8], temp2[1]), - vaddq_f64(evens[9], temp3[1]), - vaddq_f64(evens[10], temp4[1]), - vaddq_f64(evens[11], temp5[1]), - vsubq_f64(evens[0], temp0[0]), - vsubq_f64(evens[1], temp1[0]), - vsubq_f64(evens[2], 
temp2[0]), - vsubq_f64(evens[3], temp3[0]), - vsubq_f64(evens[4], temp4[0]), - vsubq_f64(evens[5], temp5[0]), - vsubq_f64(evens[6], temp0[1]), - vsubq_f64(evens[7], temp1[1]), - vsubq_f64(evens[8], temp2[1]), - vsubq_f64(evens[9], temp3[1]), - vsubq_f64(evens[10], temp4[1]), - vsubq_f64(evens[11], temp5[1]), - ] + let out3 = self + .bf6 + .perform_fft_direct([tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3]]); + store(3, out3); } } @@ -3081,49 +2912,35 @@ impl NeonF64Butterfly24 { // pub struct NeonF32Butterfly32 { - direction: FftDirection, bf8: NeonF32Butterfly8, - bf16: NeonF32Butterfly16, - rotate90: Rotate90F32, - twiddle01: float32x4_t, - twiddle23: float32x4_t, - twiddle45: float32x4_t, - twiddle67: float32x4_t, - twiddle01conj: float32x4_t, - twiddle23conj: float32x4_t, - twiddle45conj: float32x4_t, - twiddle67conj: float32x4_t, + twiddles_packed: [float32x4_t; 12], twiddle1: float32x4_t, twiddle2: float32x4_t, twiddle3: float32x4_t, - twiddle4: float32x4_t, twiddle5: float32x4_t, twiddle6: float32x4_t, twiddle7: float32x4_t, - twiddle1c: float32x4_t, - twiddle2c: float32x4_t, - twiddle3c: float32x4_t, - twiddle4c: float32x4_t, - twiddle5c: float32x4_t, - twiddle6c: float32x4_t, - twiddle7c: float32x4_t, + twiddle9: float32x4_t, + twiddle10: float32x4_t, + twiddle14: float32x4_t, + twiddle15: float32x4_t, + twiddle18: float32x4_t, + twiddle21: float32x4_t, } boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly32, 32, |this: &NeonF32Butterfly32<_>| this + .bf8 + .bf4 .direction); boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly32, 32, |this: &NeonF32Butterfly32<_>| this + .bf8 + .bf4 .direction); impl NeonF32Butterfly32 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f32::(); - let bf8 = NeonF32Butterfly8::new(direction); - let bf16 = NeonF32Butterfly16::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F32::new(true) - } else { - Rotate90F32::new(false) - }; + let tw0: Complex = Complex { re: 1.0, im: 0.0 }; let tw1: Complex = twiddles::compute_twiddle(1, 32, direction); let tw2: Complex = twiddles::compute_twiddle(2, 32, direction); let tw3: Complex = twiddles::compute_twiddle(3, 32, direction); @@ -3131,264 +2948,202 @@ impl NeonF32Butterfly32 { let tw5: Complex = twiddles::compute_twiddle(5, 32, direction); let tw6: Complex = twiddles::compute_twiddle(6, 32, direction); let tw7: Complex = twiddles::compute_twiddle(7, 32, direction); - let twiddle01 = unsafe { vld1q_f32([1.0, 0.0, tw1.re, tw1.im].as_ptr()) }; - let twiddle23 = unsafe { vld1q_f32([tw2.re, tw2.im, tw3.re, tw3.im].as_ptr()) }; - let twiddle45 = unsafe { vld1q_f32([tw4.re, tw4.im, tw5.re, tw5.im].as_ptr()) }; - let twiddle67 = unsafe { vld1q_f32([tw6.re, tw6.im, tw7.re, tw7.im].as_ptr()) }; - let twiddle01conj = unsafe { vld1q_f32([1.0, 0.0, tw1.re, -tw1.im].as_ptr()) }; - let twiddle23conj = unsafe { vld1q_f32([tw2.re, -tw2.im, tw3.re, -tw3.im].as_ptr()) }; - let twiddle45conj = unsafe { vld1q_f32([tw4.re, -tw4.im, tw5.re, -tw5.im].as_ptr()) }; - let twiddle67conj = unsafe { vld1q_f32([tw6.re, -tw6.im, tw7.re, -tw7.im].as_ptr()) }; - let twiddle1 = unsafe { vld1q_f32([tw1.re, tw1.im, tw1.re, tw1.im].as_ptr()) }; - let twiddle2 = unsafe { vld1q_f32([tw2.re, tw2.im, tw2.re, tw2.im].as_ptr()) }; - let twiddle3 = unsafe { vld1q_f32([tw3.re, tw3.im, tw3.re, tw3.im].as_ptr()) }; - let twiddle4 = unsafe { vld1q_f32([tw4.re, tw4.im, tw4.re, tw4.im].as_ptr()) }; - let twiddle5 = unsafe { vld1q_f32([tw5.re, tw5.im, tw5.re, tw5.im].as_ptr()) }; - let twiddle6 
= unsafe { vld1q_f32([tw6.re, tw6.im, tw6.re, tw6.im].as_ptr()) }; - let twiddle7 = unsafe { vld1q_f32([tw7.re, tw7.im, tw7.re, tw7.im].as_ptr()) }; - let twiddle1c = unsafe { vld1q_f32([tw1.re, -tw1.im, tw1.re, -tw1.im].as_ptr()) }; - let twiddle2c = unsafe { vld1q_f32([tw2.re, -tw2.im, tw2.re, -tw2.im].as_ptr()) }; - let twiddle3c = unsafe { vld1q_f32([tw3.re, -tw3.im, tw3.re, -tw3.im].as_ptr()) }; - let twiddle4c = unsafe { vld1q_f32([tw4.re, -tw4.im, tw4.re, -tw4.im].as_ptr()) }; - let twiddle5c = unsafe { vld1q_f32([tw5.re, -tw5.im, tw5.re, -tw5.im].as_ptr()) }; - let twiddle6c = unsafe { vld1q_f32([tw6.re, -tw6.im, tw6.re, -tw6.im].as_ptr()) }; - let twiddle7c = unsafe { vld1q_f32([tw7.re, -tw7.im, tw7.re, -tw7.im].as_ptr()) }; - Self { - direction, - bf8, - bf16, - rotate90, - twiddle01, - twiddle23, - twiddle45, - twiddle67, - twiddle01conj, - twiddle23conj, - twiddle45conj, - twiddle67conj, - twiddle1, - twiddle2, - twiddle3, - twiddle4, - twiddle5, - twiddle6, - twiddle7, - twiddle1c, - twiddle2c, - twiddle3c, - twiddle4c, - twiddle5c, - twiddle6c, - twiddle7c, + let tw8: Complex = twiddles::compute_twiddle(8, 32, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 32, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 32, direction); + let tw12: Complex = twiddles::compute_twiddle(12, 32, direction); + let tw14: Complex = twiddles::compute_twiddle(14, 32, direction); + let tw15: Complex = twiddles::compute_twiddle(15, 32, direction); + let tw18: Complex = twiddles::compute_twiddle(18, 32, direction); + let tw21: Complex = twiddles::compute_twiddle(21, 32, direction); + unsafe { + Self { + bf8: NeonF32Butterfly8::new(direction), + twiddles_packed: [ + pack_32(tw0, tw1), + pack_32(tw0, tw2), + pack_32(tw0, tw3), + pack_32(tw2, tw3), + pack_32(tw4, tw6), + pack_32(tw6, tw9), + pack_32(tw4, tw5), + pack_32(tw8, tw10), + pack_32(tw12, tw15), + pack_32(tw6, tw7), + pack_32(tw12, tw14), + pack_32(tw18, tw21), + ], + twiddle1: pack_32(tw1, tw1), + twiddle2: pack_32(tw2, tw2), + twiddle3: pack_32(tw3, tw3), + twiddle5: pack_32(tw5, tw5), + twiddle6: pack_32(tw6, tw6), + twiddle7: pack_32(tw7, tw7), + twiddle9: pack_32(tw9, tw9), + twiddle10: pack_32(tw10, tw10), + twiddle14: pack_32(tw14, tw14), + twiddle15: pack_32(tw15, tw15), + twiddle18: pack_32(tw18, tw18), + twiddle21: pack_32(tw21, tw21), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 }); - - let out = self.perform_fft_direct(input_packed); - - write_complex_to_array_strided!(out, buffer, 2, {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); - } - - #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_contiguous( - &self, - mut buffer: impl NeonArrayMut, - ) { - let input_packed = read_complex_to_array!(buffer, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62}); - - let values = interleave_complex_f32!(input_packed, 16, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); - - let out = self.perform_parallel_fft_direct(values); - - let out_sorted = separate_interleaved_complex_f32!(out, {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}); - - write_complex_to_array_strided!(out_sorted, buffer, 2, {0,1,2,3,4,5,6,7,8,9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 }); - } - - #[inline(always)] - unsafe fn 
perform_fft_direct(&self, input: [float32x4_t; 16]) -> [float32x4_t; 16] { - // we're going to hardcode a step of split radix + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 8), + buffer.load_complex(i + 16), + buffer.load_complex(i + 24), + ] + }; - // step 1: copy and reorder the input into the scratch - let in0002 = extract_lo_lo_f32(input[0], input[1]); - let in0406 = extract_lo_lo_f32(input[2], input[3]); - let in0810 = extract_lo_lo_f32(input[4], input[5]); - let in1214 = extract_lo_lo_f32(input[6], input[7]); - let in1618 = extract_lo_lo_f32(input[8], input[9]); - let in2022 = extract_lo_lo_f32(input[10], input[11]); - let in2426 = extract_lo_lo_f32(input[12], input[13]); - let in2830 = extract_lo_lo_f32(input[14], input[15]); - - let in0105 = extract_hi_hi_f32(input[0], input[2]); - let in0913 = extract_hi_hi_f32(input[4], input[6]); - let in1721 = extract_hi_hi_f32(input[8], input[10]); - let in2529 = extract_hi_hi_f32(input[12], input[14]); - - let in3103 = extract_hi_hi_f32(input[15], input[1]); - let in0711 = extract_hi_hi_f32(input[3], input[5]); - let in1519 = extract_hi_hi_f32(input[7], input[9]); - let in2327 = extract_hi_hi_f32(input[11], input[13]); - - let in_evens = [ - in0002, in0406, in0810, in1214, in1618, in2022, in2426, in2830, - ]; + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors + let mut tmp0 = self.bf8.bf4.perform_parallel_fft_direct(load(0)); + tmp0[1] = NeonVector::mul_complex(tmp0[1], self.twiddles_packed[0]); + tmp0[2] = NeonVector::mul_complex(tmp0[2], self.twiddles_packed[1]); + tmp0[3] = NeonVector::mul_complex(tmp0[3], self.twiddles_packed[2]); + let [mid0, mid1] = transpose_complex_2x2_f32(tmp0[0], tmp0[1]); + let [mid8, mid9] = transpose_complex_2x2_f32(tmp0[2], tmp0[3]); + + let mut tmp1 = self.bf8.bf4.perform_parallel_fft_direct(load(2)); + tmp1[1] = NeonVector::mul_complex(tmp1[1], self.twiddles_packed[3]); + tmp1[2] = NeonVector::mul_complex(tmp1[2], self.twiddles_packed[4]); + tmp1[3] = NeonVector::mul_complex(tmp1[3], self.twiddles_packed[5]); + let [mid2, mid3] = transpose_complex_2x2_f32(tmp1[0], tmp1[1]); + let [mid10, mid11] = transpose_complex_2x2_f32(tmp1[2], tmp1[3]); + + let mut tmp2 = self.bf8.bf4.perform_parallel_fft_direct(load(4)); + tmp2[1] = NeonVector::mul_complex(tmp2[1], self.twiddles_packed[6]); + tmp2[2] = NeonVector::mul_complex(tmp2[2], self.twiddles_packed[7]); + tmp2[3] = NeonVector::mul_complex(tmp2[3], self.twiddles_packed[8]); + let [mid4, mid5] = transpose_complex_2x2_f32(tmp2[0], tmp2[1]); + let [mid12, mid13] = transpose_complex_2x2_f32(tmp2[2], tmp2[3]); + + let mut tmp3 = self.bf8.bf4.perform_parallel_fft_direct(load(6)); + tmp3[1] = NeonVector::mul_complex(tmp3[1], self.twiddles_packed[9]); + 
tmp3[2] = NeonVector::mul_complex(tmp3[2], self.twiddles_packed[10]); + tmp3[3] = NeonVector::mul_complex(tmp3[3], self.twiddles_packed[11]); + let [mid6, mid7] = transpose_complex_2x2_f32(tmp3[0], tmp3[1]); + let [mid14, mid15] = transpose_complex_2x2_f32(tmp3[2], tmp3[3]); + + //////////////////////////////////////////////////////////// + let mut store = |i, vectors: [float32x4_t; 8]| { + buffer.store_complex(vectors[0], i); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + buffer.store_complex(vectors[4], i + 16); + buffer.store_complex(vectors[5], i + 20); + buffer.store_complex(vectors[6], i + 24); + buffer.store_complex(vectors[7], i + 28); + }; - // step 2: column FFTs - let evens = self.bf16.perform_fft_direct(in_evens); - let mut odds1 = self + // Size-8 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self .bf8 - .perform_fft_direct([in0105, in0913, in1721, in2529]); - let mut odds3 = self - .bf8 - .perform_fft_direct([in3103, in0711, in1519, in2327]); - - // step 3: apply twiddle factors - odds1[0] = NeonVector::mul_complex(odds1[0], self.twiddle01); - odds3[0] = NeonVector::mul_complex(odds3[0], self.twiddle01conj); - - odds1[1] = NeonVector::mul_complex(odds1[1], self.twiddle23); - odds3[1] = NeonVector::mul_complex(odds3[1], self.twiddle23conj); - - odds1[2] = NeonVector::mul_complex(odds1[2], self.twiddle45); - odds3[2] = NeonVector::mul_complex(odds3[2], self.twiddle45conj); + .perform_parallel_fft_direct([mid0, mid1, mid2, mid3, mid4, mid5, mid6, mid7]); + store(0, out0); - odds1[3] = NeonVector::mul_complex(odds1[3], self.twiddle67); - odds3[3] = NeonVector::mul_complex(odds3[3], self.twiddle67conj); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); - - //step 5: copy/add/subtract data back to buffer - [ - vaddq_f32(evens[0], temp0[0]), - vaddq_f32(evens[1], temp1[0]), - vaddq_f32(evens[2], temp2[0]), - vaddq_f32(evens[3], temp3[0]), - vaddq_f32(evens[4], temp0[1]), - vaddq_f32(evens[5], temp1[1]), - vaddq_f32(evens[6], temp2[1]), - vaddq_f32(evens[7], temp3[1]), - vsubq_f32(evens[0], temp0[0]), - vsubq_f32(evens[1], temp1[0]), - vsubq_f32(evens[2], temp2[0]), - vsubq_f32(evens[3], temp3[0]), - vsubq_f32(evens[4], temp0[1]), - vsubq_f32(evens[5], temp1[1]), - vsubq_f32(evens[6], temp2[1]), - vsubq_f32(evens[7], temp3[1]), - ] - } + let out1 = self + .bf8 + .perform_parallel_fft_direct([mid8, mid9, mid10, mid11, mid12, mid13, mid14, mid15]); + store(2, out1); + } + + #[inline(always)] + pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: 
Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i: usize| { + let [a0, a1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 0), buffer.load_complex(i + 32)); + let [b0, b1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 8), buffer.load_complex(i + 40)); + let [c0, c1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 16), buffer.load_complex(i + 48)); + let [d0, d1] = + transpose_complex_2x2_f32(buffer.load_complex(i + 24), buffer.load_complex(i + 56)); + [[a0, b0, c0, d0], [a1, b1, c1, d1]] + }; - #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_direct( - &self, - input: [float32x4_t; 32], - ) -> [float32x4_t; 32] { - // we're going to hardcode a step of split radix + // For each pair of columns: load the data, apply our size-4 FFT, apply twiddle factors + let [in0, in1] = load(0); + let tmp0 = self.bf8.bf4.perform_parallel_fft_direct(in0); + let mut tmp1 = self.bf8.bf4.perform_parallel_fft_direct(in1); + tmp1[1] = NeonVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = NeonVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[3] = NeonVector::mul_complex(tmp1[3], self.twiddle3); + + let [in2, in3] = load(2); + let mut tmp2 = self.bf8.bf4.perform_parallel_fft_direct(in2); + let mut tmp3 = self.bf8.bf4.perform_parallel_fft_direct(in3); + tmp2[1] = NeonVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[2] = self.bf8.bf4.rotate.rotate_both_45(tmp2[2]); + tmp2[3] = NeonVector::mul_complex(tmp2[3], self.twiddle6); + tmp3[1] = NeonVector::mul_complex(tmp3[1], self.twiddle3); + tmp3[2] = NeonVector::mul_complex(tmp3[2], self.twiddle6); + tmp3[3] = NeonVector::mul_complex(tmp3[3], self.twiddle9); + + let [in4, in5] = load(4); + let mut tmp4 = self.bf8.bf4.perform_parallel_fft_direct(in4); + let mut tmp5 = self.bf8.bf4.perform_parallel_fft_direct(in5); + tmp4[1] = self.bf8.bf4.rotate.rotate_both_45(tmp4[1]); + tmp4[2] = self.bf8.bf4.rotate.rotate_both(tmp4[2]); + tmp4[3] = self.bf8.bf4.rotate.rotate_both_135(tmp4[3]); + tmp5[1] = NeonVector::mul_complex(tmp5[1], self.twiddle5); + tmp5[2] = NeonVector::mul_complex(tmp5[2], self.twiddle10); + tmp5[3] = NeonVector::mul_complex(tmp5[3], self.twiddle15); + + let [in6, in7] = load(6); + let mut tmp6 = self.bf8.bf4.perform_parallel_fft_direct(in6); + let mut tmp7 = self.bf8.bf4.perform_parallel_fft_direct(in7); + tmp6[1] = NeonVector::mul_complex(tmp6[1], self.twiddle6); + tmp6[2] = self.bf8.bf4.rotate.rotate_both_135(tmp6[2]); + tmp6[3] = NeonVector::mul_complex(tmp6[3], self.twiddle18); + tmp7[1] = NeonVector::mul_complex(tmp7[1], self.twiddle7); + tmp7[2] = NeonVector::mul_complex(tmp7[2], self.twiddle14); + tmp7[3] = NeonVector::mul_complex(tmp7[3], self.twiddle21); + + //////////////////////////////////////////////////////////// + let mut store = |i, vectors_a: [float32x4_t; 8], vectors_b: [float32x4_t; 8]| { + for n in 0..8 { + let [a, b] = transpose_complex_2x2_f32(vectors_a[n], vectors_b[n]); + buffer.store_complex(a, i + n * 4); + buffer.store_complex(b, i + n * 4 + 32); + } + }; - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf16.perform_parallel_fft_direct([ - input[0], input[2], 
input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], input[24], input[26], input[28], input[30], - ]); - let mut odds1 = self.bf8.perform_parallel_fft_direct([ - input[1], input[5], input[9], input[13], input[17], input[21], input[25], input[29], + // Size-8 FFTs down each pair of transposed columns, storing them as soon as we're done with them + let out0 = self.bf8.perform_parallel_fft_direct([ + tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0], ]); - let mut odds3 = self.bf8.perform_parallel_fft_direct([ - input[31], input[3], input[7], input[11], input[15], input[19], input[23], input[27], + let out1 = self.bf8.perform_parallel_fft_direct([ + tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1], tmp6[1], tmp7[1], ]); + store(0, out0, out1); - // step 3: apply twiddle factors - odds1[1] = NeonVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = NeonVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = NeonVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = NeonVector::mul_complex(odds3[2], self.twiddle2c); - - odds1[3] = NeonVector::mul_complex(odds1[3], self.twiddle3); - odds3[3] = NeonVector::mul_complex(odds3[3], self.twiddle3c); - - odds1[4] = NeonVector::mul_complex(odds1[4], self.twiddle4); - odds3[4] = NeonVector::mul_complex(odds3[4], self.twiddle4c); - - odds1[5] = NeonVector::mul_complex(odds1[5], self.twiddle5); - odds3[5] = NeonVector::mul_complex(odds3[5], self.twiddle5c); - - odds1[6] = NeonVector::mul_complex(odds1[6], self.twiddle6); - odds3[6] = NeonVector::mul_complex(odds3[6], self.twiddle6c); - - odds1[7] = NeonVector::mul_complex(odds1[7], self.twiddle7); - odds3[7] = NeonVector::mul_complex(odds3[7], self.twiddle7c); - - // step 4: cross FFTs - let mut temp0 = parallel_fft2_interleaved_f32(odds1[0], odds3[0]); - let mut temp1 = parallel_fft2_interleaved_f32(odds1[1], odds3[1]); - let mut temp2 = parallel_fft2_interleaved_f32(odds1[2], odds3[2]); - let mut temp3 = parallel_fft2_interleaved_f32(odds1[3], odds3[3]); - let mut temp4 = parallel_fft2_interleaved_f32(odds1[4], odds3[4]); - let mut temp5 = parallel_fft2_interleaved_f32(odds1[5], odds3[5]); - let mut temp6 = parallel_fft2_interleaved_f32(odds1[6], odds3[6]); - let mut temp7 = parallel_fft2_interleaved_f32(odds1[7], odds3[7]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate_both(temp0[1]); - temp1[1] = self.rotate90.rotate_both(temp1[1]); - temp2[1] = self.rotate90.rotate_both(temp2[1]); - temp3[1] = self.rotate90.rotate_both(temp3[1]); - temp4[1] = self.rotate90.rotate_both(temp4[1]); - temp5[1] = self.rotate90.rotate_both(temp5[1]); - temp6[1] = self.rotate90.rotate_both(temp6[1]); - temp7[1] = self.rotate90.rotate_both(temp7[1]); - - //step 5: copy/add/subtract data back to buffer - [ - vaddq_f32(evens[0], temp0[0]), - vaddq_f32(evens[1], temp1[0]), - vaddq_f32(evens[2], temp2[0]), - vaddq_f32(evens[3], temp3[0]), - vaddq_f32(evens[4], temp4[0]), - vaddq_f32(evens[5], temp5[0]), - vaddq_f32(evens[6], temp6[0]), - vaddq_f32(evens[7], temp7[0]), - vaddq_f32(evens[8], temp0[1]), - vaddq_f32(evens[9], temp1[1]), - vaddq_f32(evens[10], temp2[1]), - vaddq_f32(evens[11], temp3[1]), - vaddq_f32(evens[12], temp4[1]), - vaddq_f32(evens[13], temp5[1]), - vaddq_f32(evens[14], temp6[1]), - vaddq_f32(evens[15], temp7[1]), - vsubq_f32(evens[0], temp0[0]), - vsubq_f32(evens[1], temp1[0]), - vsubq_f32(evens[2], temp2[0]), - vsubq_f32(evens[3], temp3[0]), - 
vsubq_f32(evens[4], temp4[0]), - vsubq_f32(evens[5], temp5[0]), - vsubq_f32(evens[6], temp6[0]), - vsubq_f32(evens[7], temp7[0]), - vsubq_f32(evens[8], temp0[1]), - vsubq_f32(evens[9], temp1[1]), - vsubq_f32(evens[10], temp2[1]), - vsubq_f32(evens[11], temp3[1]), - vsubq_f32(evens[12], temp4[1]), - vsubq_f32(evens[13], temp5[1]), - vsubq_f32(evens[14], temp6[1]), - vsubq_f32(evens[15], temp7[1]), - ] + let out2 = self.bf8.perform_parallel_fft_direct([ + tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2], tmp6[2], tmp7[2], + ]); + let out3 = self.bf8.perform_parallel_fft_direct([ + tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3], tmp6[3], tmp7[3], + ]); + store(2, out2, out3); } } @@ -3400,231 +3155,154 @@ impl NeonF32Butterfly32 { // pub struct NeonF64Butterfly32 { - direction: FftDirection, bf8: NeonF64Butterfly8, - bf16: NeonF64Butterfly16, - rotate90: Rotate90F64, twiddle1: float64x2_t, twiddle2: float64x2_t, twiddle3: float64x2_t, - twiddle4: float64x2_t, twiddle5: float64x2_t, twiddle6: float64x2_t, twiddle7: float64x2_t, - twiddle1c: float64x2_t, - twiddle2c: float64x2_t, - twiddle3c: float64x2_t, - twiddle4c: float64x2_t, - twiddle5c: float64x2_t, - twiddle6c: float64x2_t, - twiddle7c: float64x2_t, + twiddle9: float64x2_t, + twiddle10: float64x2_t, + twiddle14: float64x2_t, + twiddle15: float64x2_t, + twiddle18: float64x2_t, + twiddle21: float64x2_t, } boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly32, 32, |this: &NeonF64Butterfly32<_>| this + .bf8 + .bf4 .direction); boilerplate_fft_neon_common_butterfly!(NeonF64Butterfly32, 32, |this: &NeonF64Butterfly32<_>| this + .bf8 + .bf4 .direction); impl NeonF64Butterfly32 { #[inline(always)] pub fn new(direction: FftDirection) -> Self { assert_f64::(); - let bf8 = NeonF64Butterfly8::new(direction); - let bf16 = NeonF64Butterfly16::new(direction); - let rotate90 = if direction == FftDirection::Inverse { - Rotate90F64::new(true) - } else { - Rotate90F64::new(false) - }; - let twiddle1 = unsafe { - vld1q_f64(&twiddles::compute_twiddle::(1, 32, direction) as *const _ as *const f64) - }; - let twiddle2 = unsafe { - vld1q_f64(&twiddles::compute_twiddle::(2, 32, direction) as *const _ as *const f64) - }; - let twiddle3 = unsafe { - vld1q_f64(&twiddles::compute_twiddle::(3, 32, direction) as *const _ as *const f64) - }; - let twiddle4 = unsafe { - vld1q_f64(&twiddles::compute_twiddle::(4, 32, direction) as *const _ as *const f64) - }; - let twiddle5 = unsafe { - vld1q_f64(&twiddles::compute_twiddle::(5, 32, direction) as *const _ as *const f64) - }; - let twiddle6 = unsafe { - vld1q_f64(&twiddles::compute_twiddle::(6, 32, direction) as *const _ as *const f64) - }; - let twiddle7 = unsafe { - vld1q_f64(&twiddles::compute_twiddle::(7, 32, direction) as *const _ as *const f64) - }; - let twiddle1c = unsafe { - vld1q_f64( - &twiddles::compute_twiddle::(1, 32, direction).conj() as *const _ - as *const f64, - ) - }; - let twiddle2c = unsafe { - vld1q_f64( - &twiddles::compute_twiddle::(2, 32, direction).conj() as *const _ - as *const f64, - ) - }; - let twiddle3c = unsafe { - vld1q_f64( - &twiddles::compute_twiddle::(3, 32, direction).conj() as *const _ - as *const f64, - ) - }; - let twiddle4c = unsafe { - vld1q_f64( - &twiddles::compute_twiddle::(4, 32, direction).conj() as *const _ - as *const f64, - ) - }; - let twiddle5c = unsafe { - vld1q_f64( - &twiddles::compute_twiddle::(5, 32, direction).conj() as *const _ - as *const f64, - ) - }; - let twiddle6c = unsafe { - vld1q_f64( - &twiddles::compute_twiddle::(6, 32, direction).conj() as 
*const _ - as *const f64, - ) - }; - let twiddle7c = unsafe { - vld1q_f64( - &twiddles::compute_twiddle::(7, 32, direction).conj() as *const _ - as *const f64, - ) - }; + let tw1: Complex = twiddles::compute_twiddle(1, 32, direction); + let tw2: Complex = twiddles::compute_twiddle(2, 32, direction); + let tw3: Complex = twiddles::compute_twiddle(3, 32, direction); + let tw5: Complex = twiddles::compute_twiddle(5, 32, direction); + let tw6: Complex = twiddles::compute_twiddle(6, 32, direction); + let tw7: Complex = twiddles::compute_twiddle(7, 32, direction); + let tw9: Complex = twiddles::compute_twiddle(9, 32, direction); + let tw10: Complex = twiddles::compute_twiddle(10, 32, direction); + let tw14: Complex = twiddles::compute_twiddle(14, 32, direction); + let tw15: Complex = twiddles::compute_twiddle(15, 32, direction); + let tw18: Complex = twiddles::compute_twiddle(18, 32, direction); + let tw21: Complex = twiddles::compute_twiddle(21, 32, direction); - Self { - direction, - bf8, - bf16, - rotate90, - twiddle1, - twiddle2, - twiddle3, - twiddle4, - twiddle5, - twiddle6, - twiddle7, - twiddle1c, - twiddle2c, - twiddle3c, - twiddle4c, - twiddle5c, - twiddle6c, - twiddle7c, + unsafe { + Self { + bf8: NeonF64Butterfly8::new(direction), + twiddle1: pack_64(tw1), + twiddle2: pack_64(tw2), + twiddle3: pack_64(tw3), + twiddle5: pack_64(tw5), + twiddle6: pack_64(tw6), + twiddle7: pack_64(tw7), + twiddle9: pack_64(tw9), + twiddle10: pack_64(tw10), + twiddle14: pack_64(tw14), + twiddle15: pack_64(tw15), + twiddle18: pack_64(tw18), + twiddle21: pack_64(tw21), + } } } #[inline(always)] unsafe fn perform_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { - let values = read_complex_to_array!(buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); - - let out = self.perform_fft_direct(values); - - write_complex_to_array!(out, buffer, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); - } + // To make the best possible use of registers, we're going to write this algorithm in an unusual way + // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs + // But to reduce the number of times registers get spilled, we have these optimizations: + // 1: Load data as late as possible, not upfront + // 2: Once we're working with a piece of data, make as much progress as possible before moving on + // IE, once we load a column, we should do the FFT down the column, do twiddle factors, and do the pieces of the transpose for that column, all before starting on the next column + // 3: Store data as soon as we're finished with it, rather than waiting for the end + let load = |i| { + [ + buffer.load_complex(i), + buffer.load_complex(i + 8), + buffer.load_complex(i + 16), + buffer.load_complex(i + 24), + ] + }; - #[inline(always)] - unsafe fn perform_fft_direct(&self, input: [float64x2_t; 32]) -> [float64x2_t; 32] { - // we're going to hardcode a step of split radix + // For each column: load the data, apply our size-4 FFT, apply twiddle factors + let mut tmp1 = self.bf8.bf4.perform_fft_direct(load(1)); + tmp1[1] = NeonVector::mul_complex(tmp1[1], self.twiddle1); + tmp1[2] = NeonVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[3] = NeonVector::mul_complex(tmp1[3], self.twiddle3); + + let mut tmp2 = self.bf8.bf4.perform_fft_direct(load(2)); + tmp2[1] = 
NeonVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[2] = self.bf8.bf4.rotate.rotate_45(tmp2[2]); + tmp2[3] = NeonVector::mul_complex(tmp2[3], self.twiddle6); + + let mut tmp3 = self.bf8.bf4.perform_fft_direct(load(3)); + tmp3[1] = NeonVector::mul_complex(tmp3[1], self.twiddle3); + tmp3[2] = NeonVector::mul_complex(tmp3[2], self.twiddle6); + tmp3[3] = NeonVector::mul_complex(tmp3[3], self.twiddle9); + + let mut tmp5 = self.bf8.bf4.perform_fft_direct(load(5)); + tmp5[1] = NeonVector::mul_complex(tmp5[1], self.twiddle5); + tmp5[2] = NeonVector::mul_complex(tmp5[2], self.twiddle10); + tmp5[3] = NeonVector::mul_complex(tmp5[3], self.twiddle15); + + let mut tmp6 = self.bf8.bf4.perform_fft_direct(load(6)); + tmp6[1] = NeonVector::mul_complex(tmp6[1], self.twiddle6); + tmp6[2] = self.bf8.bf4.rotate.rotate_135(tmp6[2]); + tmp6[3] = NeonVector::mul_complex(tmp6[3], self.twiddle18); + + let mut tmp7 = self.bf8.bf4.perform_fft_direct(load(7)); + tmp7[1] = NeonVector::mul_complex(tmp7[1], self.twiddle7); + tmp7[2] = NeonVector::mul_complex(tmp7[2], self.twiddle14); + tmp7[3] = NeonVector::mul_complex(tmp7[3], self.twiddle21); + + let mut tmp4 = self.bf8.bf4.perform_fft_direct(load(4)); + tmp4[1] = self.bf8.bf4.rotate.rotate_45(tmp4[1]); + tmp4[2] = self.bf8.bf4.rotate.rotate(tmp4[2]); + tmp4[3] = self.bf8.bf4.rotate.rotate_135(tmp4[3]); + + // Do the first column last, because no twiddles means fewer temporaries forcing the above data to spill + let tmp0 = self.bf8.bf4.perform_fft_direct(load(0)); + + //////////////////////////////////////////////////////////// + let mut store = |i, vectors: [float64x2_t; 8]| { + buffer.store_complex(vectors[0], i); + buffer.store_complex(vectors[1], i + 4); + buffer.store_complex(vectors[2], i + 8); + buffer.store_complex(vectors[3], i + 12); + buffer.store_complex(vectors[4], i + 16); + buffer.store_complex(vectors[5], i + 20); + buffer.store_complex(vectors[6], i + 24); + buffer.store_complex(vectors[7], i + 28); + }; - // step 1: copy and reorder the input into the scratch - // and - // step 2: column FFTs - let evens = self.bf16.perform_fft_direct([ - input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14], - input[16], input[18], input[20], input[22], input[24], input[26], input[28], input[30], + // Size-8 FFTs down each of our transposed columns, storing them as soon as we're done with them + let out0 = self.bf8.perform_fft_direct([ + tmp0[0], tmp1[0], tmp2[0], tmp3[0], tmp4[0], tmp5[0], tmp6[0], tmp7[0], ]); - let mut odds1 = self.bf8.perform_fft_direct([ - input[1], input[5], input[9], input[13], input[17], input[21], input[25], input[29], + store(0, out0); + + let out1 = self.bf8.perform_fft_direct([ + tmp0[1], tmp1[1], tmp2[1], tmp3[1], tmp4[1], tmp5[1], tmp6[1], tmp7[1], ]); - let mut odds3 = self.bf8.perform_fft_direct([ - input[31], input[3], input[7], input[11], input[15], input[19], input[23], input[27], + store(1, out1); + + let out2 = self.bf8.perform_fft_direct([ + tmp0[2], tmp1[2], tmp2[2], tmp3[2], tmp4[2], tmp5[2], tmp6[2], tmp7[2], ]); + store(2, out2); - // step 3: apply twiddle factors - odds1[1] = NeonVector::mul_complex(odds1[1], self.twiddle1); - odds3[1] = NeonVector::mul_complex(odds3[1], self.twiddle1c); - - odds1[2] = NeonVector::mul_complex(odds1[2], self.twiddle2); - odds3[2] = NeonVector::mul_complex(odds3[2], self.twiddle2c); - - odds1[3] = NeonVector::mul_complex(odds1[3], self.twiddle3); - odds3[3] = NeonVector::mul_complex(odds3[3], self.twiddle3c); - - odds1[4] = NeonVector::mul_complex(odds1[4], 
self.twiddle4); - odds3[4] = NeonVector::mul_complex(odds3[4], self.twiddle4c); - - odds1[5] = NeonVector::mul_complex(odds1[5], self.twiddle5); - odds3[5] = NeonVector::mul_complex(odds3[5], self.twiddle5c); - - odds1[6] = NeonVector::mul_complex(odds1[6], self.twiddle6); - odds3[6] = NeonVector::mul_complex(odds3[6], self.twiddle6c); - - odds1[7] = NeonVector::mul_complex(odds1[7], self.twiddle7); - odds3[7] = NeonVector::mul_complex(odds3[7], self.twiddle7c); - - // step 4: cross FFTs - let mut temp0 = solo_fft2_f64(odds1[0], odds3[0]); - let mut temp1 = solo_fft2_f64(odds1[1], odds3[1]); - let mut temp2 = solo_fft2_f64(odds1[2], odds3[2]); - let mut temp3 = solo_fft2_f64(odds1[3], odds3[3]); - let mut temp4 = solo_fft2_f64(odds1[4], odds3[4]); - let mut temp5 = solo_fft2_f64(odds1[5], odds3[5]); - let mut temp6 = solo_fft2_f64(odds1[6], odds3[6]); - let mut temp7 = solo_fft2_f64(odds1[7], odds3[7]); - - // apply the butterfly 4 twiddle factor, which is just a rotation - temp0[1] = self.rotate90.rotate(temp0[1]); - temp1[1] = self.rotate90.rotate(temp1[1]); - temp2[1] = self.rotate90.rotate(temp2[1]); - temp3[1] = self.rotate90.rotate(temp3[1]); - temp4[1] = self.rotate90.rotate(temp4[1]); - temp5[1] = self.rotate90.rotate(temp5[1]); - temp6[1] = self.rotate90.rotate(temp6[1]); - temp7[1] = self.rotate90.rotate(temp7[1]); - - //step 5: copy/add/subtract data back to buffer - [ - vaddq_f64(evens[0], temp0[0]), - vaddq_f64(evens[1], temp1[0]), - vaddq_f64(evens[2], temp2[0]), - vaddq_f64(evens[3], temp3[0]), - vaddq_f64(evens[4], temp4[0]), - vaddq_f64(evens[5], temp5[0]), - vaddq_f64(evens[6], temp6[0]), - vaddq_f64(evens[7], temp7[0]), - vaddq_f64(evens[8], temp0[1]), - vaddq_f64(evens[9], temp1[1]), - vaddq_f64(evens[10], temp2[1]), - vaddq_f64(evens[11], temp3[1]), - vaddq_f64(evens[12], temp4[1]), - vaddq_f64(evens[13], temp5[1]), - vaddq_f64(evens[14], temp6[1]), - vaddq_f64(evens[15], temp7[1]), - vsubq_f64(evens[0], temp0[0]), - vsubq_f64(evens[1], temp1[0]), - vsubq_f64(evens[2], temp2[0]), - vsubq_f64(evens[3], temp3[0]), - vsubq_f64(evens[4], temp4[0]), - vsubq_f64(evens[5], temp5[0]), - vsubq_f64(evens[6], temp6[0]), - vsubq_f64(evens[7], temp7[0]), - vsubq_f64(evens[8], temp0[1]), - vsubq_f64(evens[9], temp1[1]), - vsubq_f64(evens[10], temp2[1]), - vsubq_f64(evens[11], temp3[1]), - vsubq_f64(evens[12], temp4[1]), - vsubq_f64(evens[13], temp5[1]), - vsubq_f64(evens[14], temp6[1]), - vsubq_f64(evens[15], temp7[1]), - ] + let out3 = self.bf8.perform_fft_direct([ + tmp0[3], tmp1[3], tmp2[3], tmp3[3], tmp4[3], tmp5[3], tmp6[3], tmp7[3], + ]); + store(3, out3); } } diff --git a/src/neon/neon_utils.rs b/src/neon/neon_utils.rs index 5e2ed925..0bc8db2e 100644 --- a/src/neon/neon_utils.rs +++ b/src/neon/neon_utils.rs @@ -71,6 +71,27 @@ impl Rotate90F32 { vreinterpretq_u32_f32(self.sign_both), )) } + + #[inline(always)] + pub unsafe fn rotate_both_45(&self, values: float32x4_t) -> float32x4_t { + let rotated = self.rotate_both(values); + let sum = vaddq_f32(rotated, values); + vmulq_f32(sum, vmovq_n_f32(0.5f32.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_both_135(&self, values: float32x4_t) -> float32x4_t { + let rotated = self.rotate_both(values); + let diff = vsubq_f32(rotated, values); + vmulq_f32(diff, vmovq_n_f32(0.5f32.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_both_225(&self, values: float32x4_t) -> float32x4_t { + let rotated = self.rotate_both(values); + let diff = vaddq_f32(rotated, values); + vmulq_f32(diff, vmovq_n_f32(-(0.5f32.sqrt()))) + } } // 
Pack low (1st) complex @@ -202,6 +223,27 @@ impl Rotate90F64 { vreinterpretq_u64_f64(self.sign), )) } + + #[inline(always)] + pub unsafe fn rotate_45(&self, values: float64x2_t) -> float64x2_t { + let rotated = self.rotate(values); + let sum = vaddq_f64(rotated, values); + vmulq_f64(sum, vmovq_n_f64(0.5f64.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_135(&self, values: float64x2_t) -> float64x2_t { + let rotated = self.rotate(values); + let diff = vsubq_f64(rotated, values); + vmulq_f64(diff, vmovq_n_f64(0.5f64.sqrt())) + } + + #[inline(always)] + pub unsafe fn rotate_225(&self, values: float64x2_t) -> float64x2_t { + let rotated = self.rotate(values); + let diff = vaddq_f64(rotated, values); + vmulq_f64(diff, vmovq_n_f64(-(0.5f64.sqrt()))) + } } #[cfg(test)] diff --git a/src/neon/neon_vector.rs b/src/neon/neon_vector.rs index c6b90e5d..0cf7edd1 100644 --- a/src/neon/neon_vector.rs +++ b/src/neon/neon_vector.rs @@ -147,6 +147,9 @@ pub trait NeonVector: Copy + Debug + Send + Sync { unsafe fn store_partial_lo_complex(ptr: *mut Complex, data: Self); unsafe fn store_partial_hi_complex(ptr: *mut Complex, data: Self); + // math ops + unsafe fn neg(a: Self) -> Self; + /// Generates a chunk of twiddle factors starting at (X,Y) and incrementing X `COMPLEX_PER_VECTOR` times. /// The result will be [twiddle(x*y, len), twiddle((x+1)*y, len), twiddle((x+2)*y, len), ...] for as many complex numbers fit in a vector unsafe fn make_mixedradix_twiddle_chunk( @@ -212,6 +215,11 @@ impl NeonVector for float32x4_t { vst1_f32(ptr as *mut f32, high); } + #[inline(always)] + unsafe fn neg(a: Self) -> Self { + vnegq_f32(a) + } + #[inline(always)] unsafe fn make_mixedradix_twiddle_chunk( x: usize, @@ -315,6 +323,11 @@ impl NeonVector for float64x2_t { unimplemented!("Impossible to do a partial store of complex f64's"); } + #[inline(always)] + unsafe fn neg(a: Self) -> Self { + vnegq_f64(a) + } + #[inline(always)] unsafe fn make_mixedradix_twiddle_chunk( x: usize, From 38bcad70e5970abd2c4a5ba882d48929884afc90 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Sun, 25 Feb 2024 16:57:26 -0800 Subject: [PATCH 08/13] Cargo fmt --- src/neon/neon_butterflies.rs | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index dabead51..64ed6996 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -702,7 +702,10 @@ impl NeonF32Butterfly4 { } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_direct(&self, values: [float32x4_t; 4]) -> [float32x4_t; 4] { + pub(crate) unsafe fn perform_parallel_fft_direct( + &self, + values: [float32x4_t; 4], + ) -> [float32x4_t; 4] { //we're going to hardcode a step of mixed radix //aka we're going to do the six step algorithm @@ -2310,9 +2313,9 @@ pub struct NeonF32Butterfly16 { twiddle9: float32x4_t, } -boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly16, 16, |this: &NeonF32Butterfly16< - _, ->| this.bf4.direction); +boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly16, 16, |this: &NeonF32Butterfly16<_>| this + .bf4 + .direction); boilerplate_fft_neon_common_butterfly!(NeonF32Butterfly16, 16, |this: &NeonF32Butterfly16<_>| this .bf4 .direction); @@ -2399,7 +2402,10 @@ impl NeonF32Butterfly16 { store(2, out1); } - pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { + pub(crate) unsafe fn perform_parallel_fft_contiguous( + &self, + mut buffer: impl NeonArrayMut, + ) { // To make the best possible use of 
registers, we're going to write this algorithm in an unusual way // It's 4x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-4 FFTs again // But to reduce the number of times registers get spilled, we have these optimizations: @@ -2699,7 +2705,10 @@ impl NeonF32Butterfly24 { } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { + pub(crate) unsafe fn perform_parallel_fft_contiguous( + &self, + mut buffer: impl NeonArrayMut, + ) { // To make the best possible use of registers, we're going to write this algorithm in an unusual way // It's 6x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-6 FFTs // But to reduce the number of times registers get spilled, we have these optimizations: @@ -3061,7 +3070,10 @@ impl NeonF32Butterfly32 { } #[inline(always)] - pub(crate) unsafe fn perform_parallel_fft_contiguous(&self, mut buffer: impl NeonArrayMut) { + pub(crate) unsafe fn perform_parallel_fft_contiguous( + &self, + mut buffer: impl NeonArrayMut, + ) { // To make the best possible use of registers, we're going to write this algorithm in an unusual way // It's 8x4 mixed radix, so we're going to do the usual steps of size-4 FFTs down the columns, apply twiddle factors, then transpose and do size-8 FFTs // But to reduce the number of times registers get spilled, we have these optimizations: From a08065b909c30c0066a7b9e82b42ca6c5cd277e5 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Mon, 26 Feb 2024 19:37:19 -0800 Subject: [PATCH 09/13] Unconditionally use butterfly32 in the neon planner --- src/neon/neon_planner.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/neon/neon_planner.rs b/src/neon/neon_planner.rs index d7285568..6f3634f1 100644 --- a/src/neon/neon_planner.rs +++ b/src/neon/neon_planner.rs @@ -16,7 +16,6 @@ use crate::math_utils::{PrimeFactor, PrimeFactors}; const MIN_RADIX4_BITS: u32 = 6; // smallest size to consider radix 4 an option is 2^6 = 64 const MAX_RADER_PRIME_FACTOR: usize = 23; // don't use Raders if the inner fft length has prime factor larger than this -const RADIX4_USE_BUTTERFLY32_FROM: u32 = 18; // Use length 32 butterfly starting from this length /// A Recipe is a structure that describes the design of a FFT, without actually creating it. /// It is used as a middle step in the planning process. 
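With the size-32 butterflies reworked above, the planner no longer needs a threshold to decide when the length-32 base pays off: every radix-4 step strips two bits off the power-of-two exponent, so for any length 2^p2 beyond the small hard-coded cases the recursion bottoms out at 2^5 = 32 when p2 is odd and at 2^4 = 16 when p2 is even. The hunk below simply collapses the odd-exponent branch to 32. A minimal standalone sketch of that rule, using a hypothetical free function radix4_base_len written only for illustration:

fn radix4_base_len(p2: u32) -> usize {
    match p2 {
        0 => 1,
        1 => 2,
        2 => 4,
        3 => 8,
        // main case: repeated radix-4 steps preserve the exponent's parity,
        // so odd exponents land on 32 (2^5) and even exponents on 16 (2^4)
        _ => {
            if p2 % 2 == 1 {
                32
            } else {
                16
            }
        }
    }
}

fn main() {
    assert_eq!(radix4_base_len(7), 32); // 128 -> 32 in one radix-4 step
    assert_eq!(radix4_base_len(10), 16); // 1024 -> 256 -> 64 -> 16
}

The wasm-simd and sse planners receive the identical simplification in the next two patches.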
@@ -666,11 +665,7 @@ impl FftPlannerNeon { // main case: if len is a power of 4, use a base of 16, otherwise use a base of 8 _ => { if p2 % 2 == 1 { - if p2 >= RADIX4_USE_BUTTERFLY32_FROM { - 32 - } else { - 8 - } + 32 } else { 16 } From c7e48e926717e2c8fe571d61e48ed868fd3ca158 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Mon, 26 Feb 2024 19:41:02 -0800 Subject: [PATCH 10/13] Unconditionally use butterfly32 in the wasm simd planner --- src/wasm_simd/wasm_simd_planner.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/wasm_simd/wasm_simd_planner.rs b/src/wasm_simd/wasm_simd_planner.rs index f44f65bb..da65be18 100644 --- a/src/wasm_simd/wasm_simd_planner.rs +++ b/src/wasm_simd/wasm_simd_planner.rs @@ -11,7 +11,6 @@ use std::{any::TypeId, collections::HashMap, sync::Arc}; const MIN_RADIX4_BITS: u32 = 6; // smallest size to consider radix 4 an option is 2^6 = 64 const MAX_RADER_PRIME_FACTOR: usize = 23; // don't use Raders if the inner fft length has prime factor larger than this -const RADIX4_USE_BUTTERFLY32_FROM: u32 = 18; // Use length 32 butterfly starting from this power of 2 /// A Recipe is a structure that describes the design of a FFT, without actually creating it. /// It is used as a middle step in the planning process. @@ -638,11 +637,7 @@ impl FftPlannerWasmSimd { // main case: if len is a power of 4, use a base of 16, otherwise use a base of 8 _ => { if p2 % 2 == 1 { - if p2 >= RADIX4_USE_BUTTERFLY32_FROM { - 32 - } else { - 8 - } + 32 } else { 16 } From e2426649c30aae4128429db8ba71ce4a4aa4a344 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Mon, 26 Feb 2024 19:43:48 -0800 Subject: [PATCH 11/13] Unconditionally use butterfly32 in the sse planner --- src/sse/sse_planner.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/sse/sse_planner.rs b/src/sse/sse_planner.rs index 1416d6b7..1628f691 100644 --- a/src/sse/sse_planner.rs +++ b/src/sse/sse_planner.rs @@ -16,7 +16,6 @@ use crate::math_utils::{PrimeFactor, PrimeFactors}; const MIN_RADIX4_BITS: u32 = 6; // smallest size to consider radix 4 an option is 2^6 = 64 const MAX_RADER_PRIME_FACTOR: usize = 23; // don't use Raders if the inner fft length has prime factor larger than this -const RADIX4_USE_BUTTERFLY32_FROM: u32 = 18; // Use length 32 butterfly starting from this power of 2 /// A Recipe is a structure that describes the design of a FFT, without actually creating it. /// It is used as a middle step in the planning process. 
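One identity worth spelling out, since the butterfly-16 patches further below lean on it: the rotate_both_45 / rotate_135 style helpers added in neon_utils above (and their SSE and wasm-simd counterparts used in the later hunks) replace the eighth-turn twiddle multiplies. Multiplying by e^(-i*pi/4) is just (z + (-i)z) * sqrt(0.5), and multiplying by e^(-i*3pi/4) is ((-i)z - z) * sqrt(0.5), so the quarter-turn rotation already needed for the radix-4 step, plus an add or subtract and a scale, does the work of a full complex multiply. A scalar sketch of the identity, as illustration only: the function names are stand-ins for the crate's Rotate90 helpers, the signs assume the forward (negative-angle) direction, and it uses num_complex, which the crate already depends on.

use num_complex::Complex;
use std::f64::consts::FRAC_1_SQRT_2;

// Forward-direction quarter turn: multiply by -i, i.e. (re, im) -> (im, -re)
fn rotate_90(z: Complex<f64>) -> Complex<f64> {
    Complex::new(z.im, -z.re)
}

// Eighth turn: z * e^(-i*pi/4) == (z + (-i)*z) * sqrt(0.5)
fn rotate_45(z: Complex<f64>) -> Complex<f64> {
    (z + rotate_90(z)) * FRAC_1_SQRT_2
}

// Three-eighths turn: z * e^(-i*3pi/4) == ((-i)*z - z) * sqrt(0.5)
fn rotate_135(z: Complex<f64>) -> Complex<f64> {
    (rotate_90(z) - z) * FRAC_1_SQRT_2
}

fn main() {
    let z = Complex::new(3.0f64, -2.0);
    let tw2 = Complex::new(FRAC_1_SQRT_2, -FRAC_1_SQRT_2); // twiddle(2, 16) = e^(-i*pi/4)
    let tw6 = Complex::new(-FRAC_1_SQRT_2, -FRAC_1_SQRT_2); // twiddle(6, 16) = e^(-i*3pi/4)
    assert!((rotate_45(z) - z * tw2).norm() < 1e-12);
    assert!((rotate_135(z) - z * tw6).norm() < 1e-12);
}

Per call this trades one packed complex multiply and one stored twiddle vector for an add and a scale, which is why the twiddle2 / twiddle6 fields can be dropped from the size-16 butterflies in the patches that follow.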
@@ -663,14 +662,10 @@ impl FftPlannerSse { 1 => 2, 2 => 4, 3 => 8, - // main case: if len is a power of 4, use a base of 16, otherwise use a base of 8 + // main case: if len is a power of 4, use a base of 16, otherwise use a base of 32 _ => { if p2 % 2 == 1 { - if p2 >= RADIX4_USE_BUTTERFLY32_FROM { - 32 - } else { - 8 - } + 32 } else { 16 } From 9498352a90631ea658b3745b092d98ed63a5d452 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Mon, 26 Feb 2024 20:52:10 -0800 Subject: [PATCH 12/13] Use rotate45/135 for f32 size 16 --- src/sse/sse_butterflies.rs | 12 ++++-------- src/wasm_simd/wasm_simd_butterflies.rs | 12 ++++-------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/sse/sse_butterflies.rs b/src/sse/sse_butterflies.rs index 64128c5b..e537a1d8 100644 --- a/src/sse/sse_butterflies.rs +++ b/src/sse/sse_butterflies.rs @@ -2288,9 +2288,7 @@ pub struct SseF32Butterfly16 { bf4: SseF32Butterfly4, twiddles_packed: [__m128; 6], twiddle1: __m128, - twiddle2: __m128, twiddle3: __m128, - twiddle6: __m128, twiddle9: __m128, } @@ -2323,9 +2321,7 @@ impl SseF32Butterfly16 { pack_32(tw6, tw9), ], twiddle1: pack_32(tw1, tw1), - twiddle2: pack_32(tw2, tw2), twiddle3: pack_32(tw3, tw3), - twiddle6: pack_32(tw6, tw6), twiddle9: pack_32(tw9, tw9), } } @@ -2409,11 +2405,11 @@ impl SseF32Butterfly16 { let [in2, in3] = load(2); let mut tmp2 = self.bf4.perform_parallel_fft_direct(in2); let mut tmp3 = self.bf4.perform_parallel_fft_direct(in3); - tmp2[1] = SseVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[1] = self.bf4.rotate.rotate_both_45(tmp2[1]); tmp2[2] = self.bf4.rotate.rotate_both(tmp2[2]); - tmp2[3] = SseVector::mul_complex(tmp2[3], self.twiddle6); + tmp2[3] = self.bf4.rotate.rotate_both_135(tmp2[3]); tmp3[1] = SseVector::mul_complex(tmp3[1], self.twiddle3); - tmp3[2] = SseVector::mul_complex(tmp3[2], self.twiddle6); + tmp3[2] = self.bf4.rotate.rotate_both_135(tmp3[2]); tmp3[3] = SseVector::mul_complex(tmp3[3], self.twiddle9); // Do these last, because fewer twiddles means fewer temporaries forcing the above data to spill @@ -2421,7 +2417,7 @@ impl SseF32Butterfly16 { let tmp0 = self.bf4.perform_parallel_fft_direct(in0); let mut tmp1 = self.bf4.perform_parallel_fft_direct(in1); tmp1[1] = SseVector::mul_complex(tmp1[1], self.twiddle1); - tmp1[2] = SseVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[2] = self.bf4.rotate.rotate_both_45(tmp1[2]); tmp1[3] = SseVector::mul_complex(tmp1[3], self.twiddle3); //////////////////////////////////////////////////////////// diff --git a/src/wasm_simd/wasm_simd_butterflies.rs b/src/wasm_simd/wasm_simd_butterflies.rs index e756e132..88873163 100644 --- a/src/wasm_simd/wasm_simd_butterflies.rs +++ b/src/wasm_simd/wasm_simd_butterflies.rs @@ -2414,9 +2414,7 @@ pub struct WasmSimdF32Butterfly16 { bf4: WasmSimdF32Butterfly4, twiddles_packed: [v128; 6], twiddle1: v128, - twiddle2: v128, twiddle3: v128, - twiddle6: v128, twiddle9: v128, } @@ -2453,9 +2451,7 @@ impl WasmSimdF32Butterfly16 { pack_32(tw6, tw9), ], twiddle1: pack_32(tw1, tw1), - twiddle2: pack_32(tw2, tw2), twiddle3: pack_32(tw3, tw3), - twiddle6: pack_32(tw6, tw6), twiddle9: pack_32(tw9, tw9), } } @@ -2564,11 +2560,11 @@ impl WasmSimdF32Butterfly16 { let [in2, in3] = Self::load_parallel_chunk(&buffer, 2); let mut tmp2 = self.bf4.perform_parallel_fft_direct(in2); let mut tmp3 = self.bf4.perform_parallel_fft_direct(in3); - tmp2[1] = mul_complex_f32(tmp2[1], self.twiddle2); + tmp2[1] = self.bf4.rotate.rotate_both_45(tmp2[1]); tmp2[2] = self.bf4.rotate.rotate_both(tmp2[2]); - tmp2[3] = 
mul_complex_f32(tmp2[3], self.twiddle6); + tmp2[3] = self.bf4.rotate.rotate_both_135(tmp2[3]); tmp3[1] = mul_complex_f32(tmp3[1], self.twiddle3); - tmp3[2] = mul_complex_f32(tmp3[2], self.twiddle6); + tmp3[2] = self.bf4.rotate.rotate_both_135(tmp3[2]); tmp3[3] = mul_complex_f32(tmp3[3], self.twiddle9); // Do these last, because fewer twiddles means fewer temporaries forcing the above data to spill @@ -2576,7 +2572,7 @@ impl WasmSimdF32Butterfly16 { let tmp0 = self.bf4.perform_parallel_fft_direct(in0); let mut tmp1 = self.bf4.perform_parallel_fft_direct(in1); tmp1[1] = mul_complex_f32(tmp1[1], self.twiddle1); - tmp1[2] = mul_complex_f32(tmp1[2], self.twiddle2); + tmp1[2] = self.bf4.rotate.rotate_both_45(tmp1[2]); tmp1[3] = mul_complex_f32(tmp1[3], self.twiddle3); // Size-4 FFTs down each pair of transposed columns, storing them as soon as we're done with them From a83161f511920f709dcf6d3bccce104f5ca580f0 Mon Sep 17 00:00:00 2001 From: Elliott Mahler Date: Mon, 26 Feb 2024 21:00:14 -0800 Subject: [PATCH 13/13] Use rotate45/135 for neon butterfly16 --- src/neon/neon_butterflies.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index 64ed6996..2e76b634 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -2307,9 +2307,7 @@ pub struct NeonF32Butterfly16 { bf4: NeonF32Butterfly4, twiddles_packed: [float32x4_t; 6], twiddle1: float32x4_t, - twiddle2: float32x4_t, twiddle3: float32x4_t, - twiddle6: float32x4_t, twiddle9: float32x4_t, } @@ -2342,9 +2340,7 @@ impl NeonF32Butterfly16 { pack_32(tw6, tw9), ], twiddle1: pack_32(tw1, tw1), - twiddle2: pack_32(tw2, tw2), twiddle3: pack_32(tw3, tw3), - twiddle6: pack_32(tw6, tw6), twiddle9: pack_32(tw9, tw9), } } @@ -2429,11 +2425,11 @@ impl NeonF32Butterfly16 { let [in2, in3] = load(2); let mut tmp2 = self.bf4.perform_parallel_fft_direct(in2); let mut tmp3 = self.bf4.perform_parallel_fft_direct(in3); - tmp2[1] = NeonVector::mul_complex(tmp2[1], self.twiddle2); + tmp2[1] = self.bf4.rotate.rotate_both_45(tmp2[1]); tmp2[2] = self.bf4.rotate.rotate_both(tmp2[2]); - tmp2[3] = NeonVector::mul_complex(tmp2[3], self.twiddle6); + tmp2[3] = self.bf4.rotate.rotate_both_135(tmp2[3]); tmp3[1] = NeonVector::mul_complex(tmp3[1], self.twiddle3); - tmp3[2] = NeonVector::mul_complex(tmp3[2], self.twiddle6); + tmp3[2] = self.bf4.rotate.rotate_both_135(tmp3[2]); tmp3[3] = NeonVector::mul_complex(tmp3[3], self.twiddle9); // Do these last, because fewer twiddles means fewer temporaries forcing the above data to spill @@ -2441,7 +2437,7 @@ impl NeonF32Butterfly16 { let tmp0 = self.bf4.perform_parallel_fft_direct(in0); let mut tmp1 = self.bf4.perform_parallel_fft_direct(in1); tmp1[1] = NeonVector::mul_complex(tmp1[1], self.twiddle1); - tmp1[2] = NeonVector::mul_complex(tmp1[2], self.twiddle2); + tmp1[2] = self.bf4.rotate.rotate_both_45(tmp1[2]); tmp1[3] = NeonVector::mul_complex(tmp1[3], self.twiddle3); ////////////////////////////////////////////////////////////