Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast Fourier Transform #313

Merged
merged 7 commits into from
Jan 26, 2021
Merged

Fast Fourier Transform #313

merged 7 commits into from
Jan 26, 2021

Conversation

duvenaud
Copy link
Contributor

@duvenaud duvenaud commented Dec 8, 2020

The inner loop is now parallelizable by the compiler, since it uses yieldAccum over AddMonoid Complex.

It would be nice to be able to enforce that the input size is a power of two using the type system. I worked out how to do this in a way that typechecks: ((Fin pow)=>m)=>a expresses arrays of size m^pow. But this crashes currently due to #146.

At some point, I'm hoping to make a blog post comparing it to the Fhutark FFT example that I initially based this one off of. The main emphasis would be on the additional safety that comes from types (and maybe more generality once Complex can be parameterized by its floating-point representation). But the coding style is also very different - just like when writing the fluidsim, I originally started writing lots of helper functions for vectorized operations, like gather, scatter, and different flavors of zip / unzip. But then I realized it would be easier just to write fused for loops. So the code for the same function ends up looking very different.

Misc notes:

  • Eventually I'm hoping to move all the @s to marked unsafe one-liners, following @apaszke 's style, or remove them by using tables as index sets.
  • These functions are linear, but I don't think the compiler will be able to recognize that.
  • I'm open to moving most of the utility functions to the prelude.

@google-cla google-cla bot added the cla: yes label Dec 8, 2020
@dan-zheng
Copy link
Collaborator

dev branch has been merged into main, now that plotting has been revamped. dev branch will be deleted now - we can continue development on main.

I'll change the base branch of this PR to main now.

@dan-zheng dan-zheng changed the base branch from dev to main December 18, 2020 23:00
@duvenaud duvenaud force-pushed the fft branch 2 times, most recently from ad0918b to 7fc0f11 Compare December 23, 2020 23:01
@duvenaud duvenaud changed the title [WIP] Fast Fourier Transform Fast Fourier Transform Jan 4, 2021
@duvenaud
Copy link
Contributor Author

@dougalm @apaszke I've polished this and added some more tests - I think it's ready to be reviewed now.

Copy link
Collaborator

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Let me know if you want to apply the minor comments, of if you'd like to merge this.

I'd also like to better understand the data flow in this algorithm to see if we could eliminate some of the unsafe indexing, but we can also do that later.

examples/fft.dx Show resolved Hide resolved
def listToTable ((AsList n xs): List a) : (Fin n)=>a = xs

def odd_sized_palindrome (mid:a) (seq:n=>a) :
({backward:n | mid:Unit | zforward:n}=>a) = -- Alphabetical order matters here.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh... Right. I think that the order being alphabetical is a bit of a coincidence and in reality it is left "implementation defined"? I guess that this is fine for now, although things like that might deserve their own user-space defined index sets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this is bad, and when we get user-space defined index sets I'd like to do what you suggest.

examples/fft.dx Outdated

def reflect (i:n) : n =
s = size n
(s - 1 - ordinal i)@n
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: unsafeFromOrdinal should be good here

Copy link
Contributor Author

@duvenaud duvenaud Jan 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. I also got rid of all the rest of the @ in this file. Surprisingly to me, it made compilation much faster.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't looked at the IR in quite a while, but when I did recently I realized that since we've made errors use IO for printing in the frontend, any invocation of @ got really complicated. cc @dougalm

examples/fft.dx Outdated
for i:(Fin log2n).
ipow2 = intpow 2 (ordinal i)
copy = get refOuter
refOuter := yieldAccum (AddMonoid Complex) \ref.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The updates to State should be atomic, so you can keep getting from refOuter in the body FWIW.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, I didn't realize that. Thanks for the tip, it made the code shorter, avoids an extra copy, and made it a bit more readable, imo.

@duvenaud
Copy link
Contributor Author

LGTM. Let me know if you want to apply the minor comments, of if you'd like to merge this.

I addressed the ones I could, I'd like to merge now, I think.

I'd also like to better understand the data flow in this algorithm to see if we could eliminate some of the unsafe indexing, but we can also do that later.

There are 4 remaining unsafe indexes besides reflect. I think I know how to get rid of 2 of them now, and the other 2 once we have power-of-2 index sets. But for now I just gave them better names and grouped them a little more sensibly so it's easier to understand what they do.

for i:(Fin log2n).
ipow2 = intpow 2 (ordinal i)
xRef := yieldAccum (AddMonoid Complex) \bufRef.
for j:(Fin halfn). -- Executes in parallel.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You wish 😆 It reads from a stateful reference, so I doubt we would ever try to parallelize this. Of course we could (and should) optimize it so that reads are ok, but it's not implemented at the moment!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I think it's an effect of me advocating for moving the copy part into the loop 😕

@apaszke apaszke merged commit 718a86c into google-research:main Jan 26, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants