[kernel] Refactor FA kernel to be FA_transV when possible. #568

Merged: raikonenfnu merged 1 commit into nod-ai:main from setTransposeProducerAttention on Nov 19, 2024

Conversation

@raikonenfnu (Member) commented on Nov 19, 2024

The Flash Attention transpose_V variant is significantly faster than the non transpose_V variant. This is because many MM intrinsics are mmtb by default, so FA transpose_V allows for better/more contiguous reads from shared memory to registers, improving attention performance substantially. This also makes FP8 faster than FP16: I have tested that it indeed improves SDXL performance on FP8, making our FP8 model faster than our FP16 model.
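For readers less familiar with the variant, here is a minimal NumPy sketch of the transpose_V idea, assuming standard single-head attention shapes and that "mmtb" refers to MMA intrinsics that take the B operand transposed; this is an illustration of the layout and math, not the kernel code in this PR:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Single-head shapes following the PR's (M, K1, K2, N) naming.
M, K1, K2, N = 128, 64, 256, 64
q = np.random.randn(M, K1).astype(np.float32)
k = np.random.randn(K2, K1).astype(np.float32)
v = np.random.randn(K2, N).astype(np.float32)

# Baseline: the second matmul P @ V reduces over K2, which is the slow (row)
# dimension of V, so the reduction walks V with a stride of N elements.
p = softmax(q @ k.T / np.sqrt(K1))
out_ref = p @ v

# transpose_V variant: the kernel consumes V already transposed ([N, K2]).
# Now K2, the reduction dimension of the second matmul, is the fastest
# (contiguous) dimension of the B operand -- the layout an mmtb-style
# intrinsic expects -- while the math is unchanged.
v_t = np.ascontiguousarray(v.T)   # [N, K2]
out_transv = p @ v_t.T            # identical result, different V layout

assert np.allclose(out_ref, out_transv, atol=1e-4)
```

The point of the sketch is only that, with V transposed, the reduction axis of the second matmul becomes contiguous in memory, which is what enables the coalesced shared-memory-to-register reads described above.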

I have also tested/confirmed that, if we do not find any producers to fuse with, the transpose seems to re-fuse back into the attention. Hence, the worst case is the same performance as before we un-split the transpose.

Here is some data from a microbenchmark with real sizes from SDXL:

(B0, B1, M, K1, K2, N) = (2, 10, 4096, 64, 4096, 64), over 100 runs:

| Variant | FP16 | FP8 |
| --- | --- | --- |
| non transpose_V | 22.7 ms | 23.8 ms |
| transpose_V | 20.1 ms | 17.5 ms |

Additionally, this PR moves the reduction dimension of attention to the fastest dimension. This is preferable because many optimization passes expect reduction dims to be the fastest dims, and it matches our lowering passes in IREE more closely.
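As a rough illustration of what "reduction dim as the fastest dim" means here, a sketch using the PR's (B0, B1, M, K1, K2, N) naming (the actual indexing maps produced by the lowering may differ), with small stand-in sizes:

```python
import numpy as np

# Small stand-in sizes; the SDXL microbenchmark above uses
# (B0, B1, M, K1, K2, N) = (2, 10, 4096, 64, 4096, 64).
B0, B1, M, K1, K2, N = 2, 2, 16, 8, 32, 8
q  = np.random.randn(B0, B1, M,  K1)   # K1 (reduction of first matmul) innermost
k  = np.random.randn(B0, B1, K2, K1)   # K1 innermost
vt = np.random.randn(B0, B1, N,  K2)   # K2 (reduction of second matmul) innermost

# With these layouts each reduction index is the last (fastest-varying,
# contiguous) axis of every operand it appears in.
s = np.einsum("bhmk,bhck->bhmc", q, k)     # first matmul, reduce over K1
p = np.exp(s - s.max(-1, keepdims=True))
p /= p.sum(-1, keepdims=True)              # softmax over K2
o = np.einsum("bhmc,bhnc->bhmn", p, vt)    # second matmul, reduce over K2
print(o.shape)                             # (2, 2, 16, 8) == (B0, B1, M, N)
```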

@raikonenfnu force-pushed the setTransposeProducerAttention branch 2 times, most recently from 3372458 to cf7bc9d on November 19, 2024 07:05
Flash Attention transpose_V variant is significantly faster than the non
transpose_V variant. This is because many MM intrinsics are mmtb by
default, so FA transpose_V allows for better/more contiguous reads from
shared memory to registers, improving attention performance
substantially. This also makes FP8 faster than FP16.

Additionally, this PR moves the reduction dimension of attention to the
fastest dimension. This is preferable because many optimization passes
expect reduction dims to be the fastest dims, and it matches our
lowering passes in IREE more closely.

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
@nithinsubbiah (Contributor) left a comment

Nice!

@raikonenfnu merged commit a7feae8 into nod-ai:main on Nov 19, 2024
4 of 5 checks passed