Support AMD ROCm on FlashAttention 2 #1010
Conversation
rocking5566 commented Jun 26, 2024 (edited)
- This PR implements the AMD / ROCm version of the C++ flash API:
  - mha_fwd
  - mha_varlen_fwd
  - mha_bwd
  - mha_varlen_bwd
- The kernel implementation comes from composable kernel.
- The C++ API is the same as the original version, so the Python interface can be used unchanged (a usage sketch follows below).
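Since the Python interface is shared, the usual upstream call path should apply to a ROCm build as well. A minimal sketch, assuming flash-attn is installed and the GPU is a supported arch (MI200/MI300); shapes and dtypes follow the upstream `flash_attn_func` contract:

```python
# Minimal sketch: the shared Python interface on a ROCm build (assumes a working
# flash-attn install on a supported arch such as gfx90a/gfx94x).
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 128
# flash-attn expects fp16 or bf16 inputs with layout (batch, seqlen, nheads, headdim).
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # on ROCm, dispatches to the CK-based mha_fwd
print(out.shape)  # (batch, seqlen, nheads, headdim)
```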
Sync with test_flash_attn.py
Ck tile/flash attention
Use same python as build flash-attn to generate ck kernel
[WIP] update to latest ck
I got this error while building fmha_bwd_d128_fp16_batch for gfx1100. It was due to _Float16 not being implemented in the instantiation: the template doesn't accept T=fp16, but does accept T=bf16 with N==2 or 4. Is this expected? I saw allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"].
I also get an error when compiling for the 7900 XTX (gfx1100). Could you tell me, does flash-attention support this card?
They don't support it yet; "native" should mean the CPU. I dug into it, here are my findings.
Here is the problematic template instantiation I figured out; hope this helps:
I'll keep digging into it. My plan is to generate a group of kernels fitting the 7900 XTX; I'll post my results if I have any findings.
I tried this on a known-good configuration: using TRL, I am able to run without flash attention, and I am able to run with the ROCm version of flash attention. But using this PR, I get an error. I can also reproduce the same error in axolotl, so it's not a problem specific to TRL.
Also please note that we were able to test successfully when using this branch in the ROCm repo. It is only when we merge it into the main branch of the Dao-AILab repo that this error presents. We have independently reproduced this error on several machines and several environments. @rocking5566 we would be happy to have a call and troubleshoot this with you, if you would like to reach out: eric@tensorwave.com
@minzhezhou Thanks for your time. We only support MI200 & MI300 at this time; that is why we put gfx90a/gfx94x in the allowed_archs list.
Hi @poyenc, thanks for the reminder. Do you mean it is technically impossible to make it work on Navi, or is it just not on the official roadmap yet?
@minzhezhou thanks for your attention. Targets other than gfx90a & gfx94x are not on the roadmap yet; currently we are focusing on MI300 platforms.
Now that the conflicts are resolved, can we merge this? Thanks!
I have an MI50 and an MI100, and I'm looking forward to gfx906 and gfx908 support; they are so much cheaper.
You can use this script to query the arch and exclude any archs that haven't been well tested (only gfx942 is tested):
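The script itself isn't reproduced in this thread; below is a minimal sketch of how the detected arch could be queried and filtered on a ROCm build of PyTorch, using the `gcnArchName` device property. The allow-list values are taken from the comments above; the "well tested" set is an assumption based on the note that only gfx942 has been tested.

```python
# Hedged sketch (not the original script): query each GPU's arch on a ROCm build
# of PyTorch and flag archs that haven't been well tested.
import torch

WELL_TESTED = {"gfx942"}                                          # per the comment above
OFFICIALLY_SUPPORTED = {"gfx90a", "gfx940", "gfx941", "gfx942"}   # MI200 / MI300

for i in range(torch.cuda.device_count()):
    # On ROCm, gcnArchName looks like "gfx90a:sramecc+:xnack-"; keep the base name.
    arch = torch.cuda.get_device_properties(i).gcnArchName.split(":")[0]
    if arch in WELL_TESTED:
        status = "well tested"
    elif arch in OFFICIALLY_SUPPORTED:
        status = "supported but less tested"
    else:
        status = "unsupported"
    print(f"device {i}: {arch} ({status})")
```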
gfx906 and gfx908 could at least compile; they have warp_size = 64.
Current FA compiles successfully on MI100.
That's excellent news! I cannot wait to try it on my MI100s.
MI100 worked, but MI50 failed on all tests.
That is why we only claim MI200 & MI300 are officially supported.
I installed ROCm 6.1.x on an MI50 and all tests failed -.- Hoping for MI50 support~
I am wondering why the kv cache / paged attention API is missing.
fwd_kvcache is on our roadmap. It will be coming soon.
Thanks for the kind reply!
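For reference, the upstream CUDA build already exposes the kv-cache path through `flash_attn_with_kvcache`; a minimal sketch of how that interface is typically called is below, so ROCm users know what to expect once fwd_kvcache lands. This follows the upstream Python API; ROCm support is still pending per the reply above, and the shapes/sizes here are illustrative.

```python
# Sketch of the upstream kv-cache interface (CUDA build today; the ROCm fwd_kvcache
# path discussed above is expected to mirror it). Not yet functional on ROCm.
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 16, 128
max_seqlen, new_tokens = 4096, 1

# Pre-allocated caches that are updated in place during decoding.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")  # tokens already cached

q = torch.randn(batch, new_tokens, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn_like(q)
v_new = torch.randn_like(q)

# Appends k_new/v_new at position cache_seqlens and attends over the updated cache.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True)
print(out.shape)  # (batch, new_tokens, nheads, headdim)
```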