diff --git a/examples/fft.dx b/examples/fft.dx index cba814f86..61329584f 100644 --- a/examples/fft.dx +++ b/examples/fft.dx @@ -28,8 +28,9 @@ def butterfly_ixs (j:Int) (pow2:Int) : (n & n) = (k@n, (k + pow2)@n) def power_of_2_fft (direction: Direction) (x: n=>Complex) : n=>Complex = - -- input size must be a power of 2. - -- Could enforce this with function-level types like (x: (2^log2n)=>Complex)). + -- Input size must be a power of 2. + -- Could enforce this with tables-as-index-sets like: + -- (x: (log2n=>(Fin 2))=>Complex)). dir_const = case direction of Forward -> -pi Backward -> pi @@ -37,7 +38,7 @@ def power_of_2_fft (direction: Direction) (x: n=>Complex) : n=>Complex = log2n = intlog2 (size n) halfn = idiv (size n) 2 - snd $ withState x \ref. + ans = snd $ withState x \ref. for i:(Fin log2n). pow2 = intpow 2 (ordinal i) copy = get ref @@ -52,7 +53,11 @@ def power_of_2_fft (direction: Direction) (x: n=>Complex) : n=>Complex = ref!a := copy.(j'@n) + v ref!b := copy.(j'@n) - v -def convolve_complex (u:n=>Complex) (v:m=>Complex) : ({ovals:n | padding:m }=>Complex) = + case direction of + Forward -> ans + Backward -> ans / (IToF (size n)) + +def convolve_complex (u:n=>Complex) (v:m=>Complex) : ({ovals:n | padding:m }=>Complex) = -- Convolve by pointwise multiplication in the Fourier domain. convolved_size = (size n) + (size m) - 1 working_size = nextpow2 convolved_size @@ -62,8 +67,7 @@ def convolve_complex (u:n=>Complex) (v:m=>Complex) : ({ovals:n | padding:m }=>Co spectral_v = power_of_2_fft Forward v_padded spectral_conv = for i. spectral_u.i * spectral_v.i padded_conv = power_of_2_fft Backward spectral_conv - us = slice padded_conv 0 {ovals:n | padding:m } - us / (IToF working_size) -- Todo: move into fft + slice padded_conv 0 {ovals:n | padding:m } def convolve (u:n=>Float) (v:m=>Float) : ({ovals:n | padding:m }=>Float) = u' = for i. MkComplex u.i 0.0 @@ -76,27 +80,36 @@ def convolve (u:n=>Float) (v:m=>Float) : ({ovals:n | padding:m }=>Float) = '## FFT Interface def fft (x: n=>Complex): n=>Complex = - -- Bluestein's algorithm for FFT on any size of array. - -- Todo: short circuit for powers of two. - im = MkComplex 0.0 1.0 - wks = for i:n. - ks = IToF $ ordinal i - exp $ (-im) * (MkComplex (pi * (sq ks) / (IToF (size n))) 0.0) - xq = for i. x.i * wks.i - first = exp (-im * MkComplex (pi * (IToF (size n))) 0.0) - sn = (size n) - rwks:(Fin sn)=>Complex = for i. wks.(((size n) - 1 - ordinal i)@n) - nm1 = (size n) - 1 - wq' = concat [AsList _ [first], AsList _ rwks, AsList _ (slice wks 1 (Fin nm1))] - (AsList _ wq'') = wq' - wq = for i. complex_conj wq''.i - conved = convolve_complex xq wq - convslice = slice conved (size n) n - for i. wks.i * convslice.i + if isPowerOf2 (size n) + then power_of_2_fft Forward x + else + -- Bluestein's algorithm for FFT on any size of array. + im = MkComplex 0.0 1.0 + wks = for i. + i_squared = IToF $ sq $ ordinal i + exp $ (-im) * (MkComplex (pi * i_squared / (IToF (size n))) 0.0) + + -- Turns sequence 12345 into 543212345. + -- I would break this into a helper function, but I don't know + -- how to type it. + backwards_and_forwards = Fin (2 * (size n) - 1) + baf = for i:backwards_and_forwards. + case ordinal i < size n of + True -> wks.(((size n) - 1 - ordinal i)@n) + False -> wks.(((ordinal i) - (size n) + 1)@n) + + xq = for i. x.i * wks.i + baf_conj = for i. complex_conj baf.i + conved = convolve_complex xq baf_conj + convslice = slice conved (size n - 1) n + for i. wks.i * convslice.i def ifft (xs: n=>Complex): n=>Complex = - fo = fft (for i. complex_conj xs.i) - for i. (complex_conj fo.i) / (IToF (size n)) + if isPowerOf2 (size n) + then power_of_2_fft Backward xs + else + fo = fft (for i. complex_conj xs.i) + for i. (complex_conj fo.i) / (IToF (size n)) def fft_real (x: n=>Float): n=>Complex = fft for i. MkComplex x.i 0.0 def ifft_real (x: n=>Float): n=>Complex = ifft for i. MkComplex x.i 0.0