[kernel] Refactor FA kernel to be FA_transV when possible. #568
The Flash Attention transpose_V variant is significantly faster than the non-transpose_V variant. This is because many MM intrinsics are mmtb by default, so doing FA with transpose_V allows for better, more contiguous reads from shared memory to registers, which substantially improves attention performance. It also makes FP8 faster than FP16: I have tested that it improves SDXL performance on FP8, making our FP8 model faster than our FP16 model.
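As a rough illustration of the layout idea (not the actual kernel code; the real kernel works on tiles in shared memory, and all names/shapes here are assumptions), a minimal NumPy sketch of attention with V stored transposed, so the second matmul reads the reduction dimension contiguously, as a transpose-B MMA would:

```python
# Minimal sketch of the transpose_V idea (illustrative only; hypothetical
# shapes and function names, not the PR's kernel code).
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_reference(q, k, v):
    # q: [M, D], k: [N, D], v: [N, D]
    s = q @ k.T              # [M, N]
    p = softmax(s, axis=-1)
    return p @ v             # second matmul strides across rows of V along
                             # the reduction dim N (non-contiguous reads)

def attention_transpose_v(q, k, v_t):
    # v_t: [D, N] -- V stored transposed. The second matmul becomes
    # P @ V_t^T, matching a transpose-B layout: each output column reads a
    # contiguous row of v_t along the reduction dim N.
    s = q @ k.T
    p = softmax(s, axis=-1)
    return p @ v_t.T

M, N, D = 4, 8, 16
q, k, v = (np.random.rand(x, D).astype(np.float32) for x in (M, N, N))
assert np.allclose(attention_reference(q, k, v),
                   attention_transpose_v(q, k, v.T), atol=1e-5)
```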
I have also tested/confirmed that, if we do not find any producers to fuse with, the transpose seems to re-fuse back into the attention. Hence, the worst-case performance is the same as before the transpose was split out.
For some data on a microbenchmark with real sizes from SDXL:
Additionally, this PR moves the reduction dimension of attention to the fastest dimension. This is preferable because many optimization passes expect reduction dims to be the fastest dims, and it matches IREE's lowering passes more closely.
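To illustrate what "reduction dim as the fastest dim" means for memory access (a hedged sketch with assumed shapes, not the PR's code): with row-major Q of shape [M, K] and K of shape [N, K], the contracted dimension K has the smallest stride for both operands, so each dot product walks unit-stride data.

```python
# Illustration of keeping the reduction (contraction) dim fastest-varying.
# Shapes and names are hypothetical, chosen only for the demonstration.
import numpy as np

M, N, K = 4, 8, 16
q = np.random.rand(M, K).astype(np.float32)   # K is the last (fastest) dim
k = np.random.rand(N, K).astype(np.float32)   # K is the last (fastest) dim

# Strides are in bytes; the reduction dim has the smallest stride (4 bytes).
print(q.strides, k.strides)                    # (64, 4) and (64, 4)

s = np.einsum("mk,nk->mn", q, k)               # Q @ K^T with K contracted
assert s.shape == (M, N)
```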