Skip to content

Commit

Permalink
Speed up FFT by removing all lists and special casing powers of 2.
Browse files Browse the repository at this point in the history
  • Loading branch information
duvenaud committed Dec 23, 2020
1 parent 3b24f8f commit 7fc0f11
Showing 1 changed file with 38 additions and 25 deletions.
63 changes: 38 additions & 25 deletions examples/fft.dx
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ def butterfly_ixs (j:Int) (pow2:Int) : (n & n) =
(k@n, (k + pow2)@n)

def power_of_2_fft (direction: Direction) (x: n=>Complex) : n=>Complex =
-- input size must be a power of 2.
-- Could enforce this with function-level types like (x: (2^log2n)=>Complex)).
-- Input size must be a power of 2.
-- Could enforce this with tables-as-index-sets like:
-- (x: (log2n=>(Fin 2))=>Complex)).
dir_const = case direction of
Forward -> -pi
Backward -> pi

log2n = intlog2 (size n)
halfn = idiv (size n) 2

snd $ withState x \ref.
ans = snd $ withState x \ref.
for i:(Fin log2n).
pow2 = intpow 2 (ordinal i)
copy = get ref
Expand All @@ -52,7 +53,11 @@ def power_of_2_fft (direction: Direction) (x: n=>Complex) : n=>Complex =
ref!a := copy.(j'@n) + v
ref!b := copy.(j'@n) - v

def convolve_complex (u:n=>Complex) (v:m=>Complex) : ({ovals:n | padding:m }=>Complex) =
case direction of
Forward -> ans
Backward -> ans / (IToF (size n))

def convolve_complex (u:n=>Complex) (v:m=>Complex) : ({ovals:n | padding:m }=>Complex) =
-- Convolve by pointwise multiplication in the Fourier domain.
convolved_size = (size n) + (size m) - 1
working_size = nextpow2 convolved_size
Expand All @@ -62,8 +67,7 @@ def convolve_complex (u:n=>Complex) (v:m=>Complex) : ({ovals:n | padding:m }=>Co
spectral_v = power_of_2_fft Forward v_padded
spectral_conv = for i. spectral_u.i * spectral_v.i
padded_conv = power_of_2_fft Backward spectral_conv
us = slice padded_conv 0 {ovals:n | padding:m }
us / (IToF working_size) -- Todo: move into fft
slice padded_conv 0 {ovals:n | padding:m }

def convolve (u:n=>Float) (v:m=>Float) : ({ovals:n | padding:m }=>Float) =
u' = for i. MkComplex u.i 0.0
Expand All @@ -76,27 +80,36 @@ def convolve (u:n=>Float) (v:m=>Float) : ({ovals:n | padding:m }=>Float) =
'## FFT Interface

def fft (x: n=>Complex): n=>Complex =
-- Bluestein's algorithm for FFT on any size of array.
-- Todo: short circuit for powers of two.
im = MkComplex 0.0 1.0
wks = for i:n.
ks = IToF $ ordinal i
exp $ (-im) * (MkComplex (pi * (sq ks) / (IToF (size n))) 0.0)
xq = for i. x.i * wks.i
first = exp (-im * MkComplex (pi * (IToF (size n))) 0.0)
sn = (size n)
rwks:(Fin sn)=>Complex = for i. wks.(((size n) - 1 - ordinal i)@n)
nm1 = (size n) - 1
wq' = concat [AsList _ [first], AsList _ rwks, AsList _ (slice wks 1 (Fin nm1))]
(AsList _ wq'') = wq'
wq = for i. complex_conj wq''.i
conved = convolve_complex xq wq
convslice = slice conved (size n) n
for i. wks.i * convslice.i
if isPowerOf2 (size n)
then power_of_2_fft Forward x
else
-- Bluestein's algorithm for FFT on any size of array.
im = MkComplex 0.0 1.0
wks = for i.
i_squared = IToF $ sq $ ordinal i
exp $ (-im) * (MkComplex (pi * i_squared / (IToF (size n))) 0.0)

-- Turns sequence 12345 into 543212345.
-- I would break this into a helper function, but I don't know
-- how to type it.
backwards_and_forwards = Fin (2 * (size n) - 1)
baf = for i:backwards_and_forwards.
case ordinal i < size n of
True -> wks.(((size n) - 1 - ordinal i)@n)
False -> wks.(((ordinal i) - (size n) + 1)@n)

xq = for i. x.i * wks.i
baf_conj = for i. complex_conj baf.i
conved = convolve_complex xq baf_conj
convslice = slice conved (size n - 1) n
for i. wks.i * convslice.i

def ifft (xs: n=>Complex): n=>Complex =
fo = fft (for i. complex_conj xs.i)
for i. (complex_conj fo.i) / (IToF (size n))
if isPowerOf2 (size n)
then power_of_2_fft Backward xs
else
fo = fft (for i. complex_conj xs.i)
for i. (complex_conj fo.i) / (IToF (size n))

def fft_real (x: n=>Float): n=>Complex = fft for i. MkComplex x.i 0.0
def ifft_real (x: n=>Float): n=>Complex = ifft for i. MkComplex x.i 0.0
Expand Down

0 comments on commit 7fc0f11

Please sign in to comment.