Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast Fourier Transform #313

Merged
merged 7 commits into from
Jan 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions examples/fft.dx
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
'# Fast Fourier Transform
For arrays whose size is a power of 2, we use a radix-2 algorithm based
on the [Fhutark demo](https://github.com/diku-dk/fft/blob/master/lib/github.com/diku-dk/fft/stockham-radix-2.fut#L30).
For non-power-of-2 sized arrays, it uses
[Bluestein's Algorithm](https://en.wikipedia.org/wiki/Chirp_Z-transform),
which calls the power-of-2 FFT as a subroutine.


'## General helper functions

def isPowerOf2 (x:Int) : Bool =
-- A fast trick based on bitwise AND.
-- Note: The bitwise and operator (.&.)
-- is only defined for Byte, which is why
-- I use %and here.
-- A test below verifies that this works on
-- more than 8 bit integers.
-- TODO: Make (.&.) polymorphic.
if x == 0
then False
else 0 == %and x (x - 1)

def intpow [Mul a] (base:a) (power:Int) : a =
-- Todo: Use Knuth's power trick,
-- which can give an asymptotic speedup.
reduce one (*) (for a:(Fin power). base)

def intlog2 (x:Int) : Int =
yieldState (-1) \ansRef.
runState 1 \cmpRef.
while do
if x >= (get cmpRef)
then
ansRef := (get ansRef) + 1
cmpRef := %shl (get cmpRef) 1
True
else
False

def intpow2 (power:Int) : Int = %shl 1 power
apaszke marked this conversation as resolved.
Show resolved Hide resolved

def nextpow2 (x:Int) : Int = case isPowerOf2 x of
True -> x
False -> intpow2 (1 + intlog2 x)

def reflect (i:n) : n =
unsafeFromOrdinal n ((size n) - 1 - ordinal i)

def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs

def odd_sized_palindrome (mid:a) (seq:n=>a) :
({backward:n | mid:Unit | zforward:n}=>a) = -- Alphabetical order matters here.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh... Right. I think that the order being alphabetical is a bit of a coincidence and in reality it is left "implementation defined"? I guess that this is fine for now, although things like that might deserve their own user-space defined index sets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this is bad, and when we get user-space defined index sets I'd like to do what you suggest.

-- Turns sequence 12345 into 543212345.
for i.
case i of
{|backward=i|} -> seq.(reflect i)
{|mid=i|} -> mid
{|zforward=i|} -> seq.i


'## Inner FFT functions

data FTDirection =
ForwardFT
InverseFT

def butterfly_ixs (j':halfn) (pow2:Int) : (n & n & n & n) =
-- Re-index at a finer frequency.
-- halfn must have half the size of n.
-- For explanation, see https://en.wikipedia.org/wiki/Butterfly_diagram
-- Note: with fancier index sets, this might be replacable by reshapes.
j = ordinal j'
k = ((idiv j pow2) * pow2 * 2) + mod j pow2
left_write_ix = unsafeFromOrdinal n k
right_write_ix = unsafeFromOrdinal n (k + pow2)

left_read_ix = unsafeFromOrdinal n j
right_read_ix = unsafeFromOrdinal n (j + size halfn)
(left_read_ix, right_read_ix, left_write_ix, right_write_ix)

def power_of_2_fft (direction: FTDirection) (x: n=>Complex) : n=>Complex =
-- Input size must be a power of 2.
-- Can enforce this with tables-as-index-sets like:
-- (x: (log2n=>(Fin 2))=>Complex)) once that's supported.
dir_const = case direction of
ForwardFT -> -pi
InverseFT -> pi

log2n = intlog2 (size n)
halfn = idiv (size n) 2

ans = yieldState x \xRef.
for i:(Fin log2n).
ipow2 = intpow 2 (ordinal i)
xRef := yieldAccum (AddMonoid Complex) \bufRef.
for j:(Fin halfn). -- Executes in parallel.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You wish 😆 It reads from a stateful reference, so I doubt we would ever try to parallelize this. Of course we could (and should) optimize it so that reads are ok, but it's not implemented at the moment!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I think it's an effect of me advocating for moving the copy part into the loop 😕

(left_read_ix, right_read_ix,
left_write_ix, right_write_ix) = butterfly_ixs j ipow2

-- Read one element from the last buffer, scaled.
angle = dir_const * (IToF $ mod (ordinal j) ipow2) / IToF ipow2
v = (get xRef!right_read_ix) * (MkComplex (cos angle) (sin angle))

-- Add and subtract it to the relevant places in the new buffer.
bufRef!left_write_ix += (get (xRef!left_read_ix)) + v
bufRef!right_write_ix += (get (xRef!left_read_ix)) - v

case direction of
ForwardFT -> ans
InverseFT -> ans / (IToF (size n))

def convolve_complex (u:n=>Complex) (v:m=>Complex) :
({orig_vals:n | padding:m }=>Complex) = -- Alphabetical order matters here.
-- Convolve by pointwise multiplication in the Fourier domain.
convolved_size = (size n) + (size m) - 1
working_size = nextpow2 convolved_size
u_padded = padTo (Fin working_size) zero u
v_padded = padTo (Fin working_size) zero v
spectral_u = power_of_2_fft ForwardFT u_padded
spectral_v = power_of_2_fft ForwardFT v_padded
spectral_conv = for i. spectral_u.i * spectral_v.i
padded_conv = power_of_2_fft InverseFT spectral_conv
slice padded_conv 0 {orig_vals:n | padding:m }

def convolve (u:n=>Float) (v:m=>Float) : ({orig_vals:n | padding:m }=>Float) =
u' = for i. MkComplex u.i 0.0
v' = for i. MkComplex v.i 0.0
ans = convolve_complex u' v'
for i.
(MkComplex real imag) = ans.i
real


'## FFT Interface

def fft (x: n=>Complex): n=>Complex =
if isPowerOf2 (size n)
then power_of_2_fft ForwardFT x
else
-- Bluestein's algorithm.
-- Converts the general FFT into a convolution,
-- which is then solved with calls to a power-of-2 FFT.
im = MkComplex 0.0 1.0
wks = for i.
i_squared = IToF $ sq $ ordinal i
exp $ (-im) * (MkComplex (pi * i_squared / (IToF (size n))) 0.0)

tailList = tail wks 1
back_and_forth = odd_sized_palindrome (head wks) (listToTable tailList)
xq = for i. x.i * wks.i
back_and_forth_conj = for i. complex_conj back_and_forth.i
convolution = convolve_complex xq back_and_forth_conj
convslice = slice convolution (size n - 1) n
for i. wks.i * convslice.i

def ifft (xs: n=>Complex): n=>Complex =
if isPowerOf2 (size n)
then power_of_2_fft InverseFT xs
else
unscaled_fft = fft (for i. complex_conj xs.i)
for i. (complex_conj unscaled_fft.i) / (IToF (size n))

def fft_real (x: n=>Float): n=>Complex = fft for i. MkComplex x.i 0.0
def ifft_real (x: n=>Float): n=>Complex = ifft for i. MkComplex x.i 0.0

def fft2 (x: n=>m=>Complex): n=>m=>Complex =
x' = for i. fft x.i
transpose for i. fft (transpose x').i

def ifft2 (x: n=>m=>Complex): n=>m=>Complex =
x' = for i. ifft x.i
transpose for i. ifft (transpose x').i

def fft2_real (x: n=>m=>Float): n=>m=>Complex = fft2 for i j. MkComplex x.i.j 0.0
def ifft2_real (x: n=>m=>Float): n=>m=>Complex = ifft2 for i j. MkComplex x.i.j 0.0

-------- Tests --------

a = for i. MkComplex [10.1, -2.2, 8.3, 4.5, 9.3].i 0.0
b = for i:(Fin 20) j:(Fin 70).
MkComplex (randn $ ixkey2 (newKey 0) i j) 0.0

:p a ~~ (ifft $ fft a)
> True
:p a ~~ (fft $ ifft a)
> True
:p b ~~ (ifft2 $ fft2 b)
> True
:p b ~~ (fft2 $ ifft2 b)
> True

-- Integer utility tests

:p map isPowerOf2 [0, 1, 2, 3, 256, 1024, 1024*1024, 1024*1024 + 1]
> [False, True, True, False, True, True, True, False]

:p intpow 0 0
> 1

:p intpow 0 1
> 0

:p intpow 1 0
> 1

:p intpow 1 1
> 1

:p intpow 2 1
> 2

:p intpow 2 3
> 8

:p map intlog2 [-1, 0, 1, 2, 3, 4, 5, 1023, 1024, 1025]
> [-1, -1, 0, 1, 1, 2, 2, 9, 10, 10]

:p map nextpow2 [-1, 0, 1, 2, 3, 4, 7, 8, 9, 1023, 1024, 1025]
> [1, 1, 1, 2, 4, 4, 8, 8, 16, 1024, 1024, 2048]
2 changes: 1 addition & 1 deletion makefile
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ example-names = mandelbrot pi sierpinski rejection-sampler \
regression brownian_motion particle-swarm-optimizer \
ode-integrator mcmc ctc raytrace particle-filter \
isomorphisms ode-integrator linear_algebra fluidsim \
sgd chol
sgd chol fft

test-names = uexpr-tests adt-tests type-tests eval-tests show-tests \
shadow-tests monad-tests io-tests exception-tests \
Expand Down