diff --git a/src/neon/neon_butterflies.rs b/src/neon/neon_butterflies.rs index 2e76b63..4ba2831 100644 --- a/src/neon/neon_butterflies.rs +++ b/src/neon/neon_butterflies.rs @@ -427,6 +427,7 @@ pub struct NeonF32Butterfly3 { twiddle: float32x4_t, twiddle1re: float32x4_t, twiddle1im: float32x4_t, + twiddle2im: float32x4_t, } boilerplate_fft_neon_f32_butterfly!(NeonF32Butterfly3, 3, |this: &NeonF32Butterfly3<_>| this @@ -442,6 +443,7 @@ impl NeonF32Butterfly3 { let twiddle = unsafe { vld1q_f32([tw1.re, tw1.re, -tw1.im, -tw1.im].as_ptr()) }; let twiddle1re = unsafe { vmovq_n_f32(tw1.re) }; let twiddle1im = unsafe { vmovq_n_f32(tw1.im) }; + let twiddle2im = unsafe { vmovq_n_f32(-tw1.im) }; Self { direction, _phantom: std::marker::PhantomData, @@ -449,6 +451,7 @@ impl NeonF32Butterfly3 { twiddle, twiddle1re, twiddle1im, + twiddle2im, } } #[inline(always)] @@ -498,8 +501,7 @@ impl NeonF32Butterfly3 { // This is a Neon translation of the scalar 3-point butterfly let rev12 = reverse_complex_and_negate_hi_f32(value12); let temp12pn = self.rotate.rotate_hi(vaddq_f32(value12, rev12)); - let twiddled = vmulq_f32(temp12pn, self.twiddle); - let temp = vaddq_f32(value0x, twiddled); + let temp = vfmaq_f32(value0x, temp12pn, self.twiddle); let out12 = solo_fft2_f32(temp); let out0x = vaddq_f32(value0x, temp12pn); @@ -518,17 +520,14 @@ impl NeonF32Butterfly3 { // This is a Neon translation of the scalar 3-point butterfly let x12p = vaddq_f32(value1, value2); let x12n = vsubq_f32(value1, value2); - let sum = vaddq_f32(value0, x12p); - - let temp_a = vmulq_f32(self.twiddle1re, x12p); - let temp_a = vaddq_f32(temp_a, value0); + let temp = vfmaq_f32(value0, self.twiddle1re, x12p); let n_rot = self.rotate.rotate_both(x12n); - let temp_b = vmulq_f32(self.twiddle1im, n_rot); - let x1 = vaddq_f32(temp_a, temp_b); - let x2 = vsubq_f32(temp_a, temp_b); - [sum, x1, x2] + let x0 = vaddq_f32(value0, x12p); + let x1 = vfmaq_f32(temp, self.twiddle1im, n_rot); + let x2 = vfmaq_f32(temp, self.twiddle2im, n_rot); + [x0, x1, x2] } } @@ -545,6 +544,7 @@ pub struct NeonF64Butterfly3 { rotate: Rotate90F64, twiddle1re: float64x2_t, twiddle1im: float64x2_t, + twiddle2im: float64x2_t, } boilerplate_fft_neon_f64_butterfly!(NeonF64Butterfly3, 3, |this: &NeonF64Butterfly3<_>| this @@ -559,6 +559,7 @@ impl NeonF64Butterfly3 { let tw1: Complex = twiddles::compute_twiddle(1, 3, direction); let twiddle1re = unsafe { vmovq_n_f64(tw1.re) }; let twiddle1im = unsafe { vmovq_n_f64(tw1.im) }; + let twiddle2im = unsafe { vmovq_n_f64(-tw1.im) }; Self { direction, @@ -566,6 +567,7 @@ impl NeonF64Butterfly3 { rotate, twiddle1re, twiddle1im, + twiddle2im, } } @@ -594,16 +596,14 @@ impl NeonF64Butterfly3 { // This is a Neon translation of the scalar 3-point butterfly let x12p = vaddq_f64(value1, value2); let x12n = vsubq_f64(value1, value2); - let sum = vaddq_f64(value0, x12p); - - let temp_a = vfmaq_f64(value0, self.twiddle1re, x12p); + let temp = vfmaq_f64(value0, self.twiddle1re, x12p); let n_rot = self.rotate.rotate(x12n); - let temp_b = vmulq_f64(self.twiddle1im, n_rot); - let x1 = vaddq_f64(temp_a, temp_b); - let x2 = vsubq_f64(temp_a, temp_b); - [sum, x1, x2] + let x0 = vaddq_f64(value0, x12p); + let x1 = vfmaq_f64(temp, self.twiddle1im, n_rot); + let x2 = vfmaq_f64(temp, self.twiddle2im, n_rot); + [x0, x1, x2] } }