FlashAttention works with a single GPU, but crashes with accelerate DP on multiple GPUs (FlashAttention only support fp16 and bf16 data type) #822
Comments
I'm not familiar with accelerate or how
I am getting a similar issue without training, with torch nightly on Llama, so I can confirm something's wrong! It might be on our side, but as far as I tested, all the inputs' dtypes were bfloat16 and I still got the issue.
- `transformers` version: 4.38.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Python version: 3.10.0
- Huggingface_hub version: 0.20.3
- Safetensors version: 0.4.2
- Accelerate version: 0.27.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.3.0.dev20240208+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
>>> from flash_attn import flash_attn_func
>>> import torch
>>> print(torch.__version__)
2.3.0.dev20240208+cu121
>>> flash_attn_func(torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), 1, softmax_scale=1, causal=True) ....
This doesn't work for me again, might be because I have.
The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).
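For reference, here is a minimal sketch (not from the original thread; the sizes are illustrative) of a call that satisfies those constraints, i.e. CUDA tensors in bf16 with shape (batch, seqlen, nheads, headdim):

```python
# Minimal sketch of a valid flash_attn_func call; sizes are illustrative.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64]), i.e. (batch, seqlen, nheads, headdim)
```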
The error is before that, but it seems it's torch nightly, the
I can't run the reproducer right now because StaticCache is not in transformers 4.37.2 (the latest stable version).
Yeah, FlashAttention uses (batch, seqlen, nheads, headdim) to represent inputs; however, in a lot of other software (Triton kernels, for example) there are reasons to use (batch, nheads, seqlen, headdim) for easier layout arrangement. Actually they are equivalent under this mapping:
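The mapping itself was not preserved in this thread; as an illustration only (not the commenter's original snippet), the equivalence is just a transpose of the seqlen and nheads axes:

```python
# Illustration: the two layouts differ only by swapping axes 1 and 2.
import torch

batch, seqlen, nheads, headdim = 2, 128, 8, 64
x_flash = torch.randn(batch, seqlen, nheads, headdim)  # (batch, seqlen, nheads, headdim)
x_other = x_flash.transpose(1, 2).contiguous()         # (batch, nheads, seqlen, headdim)
assert torch.equal(x_other.transpose(1, 2), x_flash)   # the mapping is its own inverse
```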
But it is weird that the error message (I have tested with the latest version) says "FlashAttention only support fp16 and bf16 data type".
I have checked the repo; we would need to update the C++ templates to support more dtypes (I have experience with near-memory chip op libraries). Currently I have to do unnecessary casts like this to help teams use FlashAttention v2:
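The commenter's cast snippet was not preserved; the following is a hypothetical sketch of that kind of workaround, assuming fp32 activations: cast the inputs down to bf16 for FlashAttention and cast the output back afterwards.

```python
# Hypothetical sketch of the "unnecessary cast" workaround described above.
import torch
from flash_attn import flash_attn_func

def attention_with_cast(q, k, v, causal=True):
    """q, k, v: fp32 CUDA tensors of shape (batch, seqlen, nheads, headdim)."""
    orig_dtype = q.dtype
    out = flash_attn_func(
        q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16),
        causal=causal,
    )
    return out.to(orig_dtype)  # cast back so downstream fp32 code is unchanged
```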
So we need to update the error message, right?
I confirm that flash-attn==2.5.6 doesn't work with torch==2.3.0a0+40ec155e58.nv24.3 nightly even though inputs are indeed in torch.bfloat16 format!
System Info
Reproduction
The following script works as expected on a single GPU, but when run on multiple GPUs with accelerate DP it fails with this error:
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
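The reproduction script itself is not included above. As a hypothetical reconstruction only (the model name and accelerate setup are assumptions, not the issue author's code), the failing configuration would look roughly like a bf16 model loaded with FlashAttention-2 and wrapped by accelerate:

```python
# Hypothetical sketch of a reproduction; not the original script from the issue.
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # assumed model; any FA2-capable causal LM
accelerator = Accelerator()

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model = accelerator.prepare(model)

inputs = tokenizer("Hello world", return_tensors="pt").to(accelerator.device)
with torch.no_grad():
    out = model(**inputs)  # reportedly fine on 1 GPU, raises the error above under DP
```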