Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[kernel] Refactor FA kernel to be FA_transV when possible. (#568)
Flash Attention transpose_V variant is significantly faster than the non transpose_V variant. This is due to many MM intrinsics being mmtb by default. Hence, doing FA transpose_V will allow for better/more contiguous reads from shared memory to register, improving the attention performance vastly. This also makes FP8 faster than FP16. I have tested that it indeed improves SDXL performance on FP8, making FP8 faster than our FP16 model. I have also tested/confirmed that, if we do not find any producers that we can fuse with, it seem to re-fuse back into the attention. Hence, the worst performance it will get is same as before we un-split the transpose. For some data on a microbenchmark with real size from SDXL: ``` (B0, B1, M, K1, K2, N): (2, 10, 4096, 64, 4096, 64) Over 100 runs: FP16 non transpose: 22.7 ms FP8 non transpose: 23.8 ms FP16 transpose: 20.1 ms FP8 transpose: 17.5 ms ``` Additionally, this PR also moves the reduction dimension of attention to the fastest dimension. This is preferable because many optimization passes expects reduction dims to be fastest dims, and will match our lowerings pass from IREE more. Signed-off-by: Stanley Winata <stanley.winata@amd.com>
- Loading branch information