This repository is a fork of the FlashAttention main repo. It extends the official implementation to support FlashAttention with flexible head dimensions.
All configurations in FlashAttention-2 are supported. In addition, we support:
- FlashAttention-2 with QKHeadDim=32, VHeadDim=64
- FlashAttention-2 with QKHeadDim=64, VHeadDim=128
- FlashAttention-2 with QKHeadDim=96, VHeadDim=192
- FlashAttention-2 with QKHeadDim=128, VHeadDim=256
- FlashAttention-2 with QKHeadDim=192, VHeadDim=128
- FlashAttention-2 with unequal num_heads_k and num_heads_v, e.g. (num_heads_q, num_heads_k, num_heads_v) = (32, 4, 16)
For head dimensions that are not yet supported, you can use the autotuner to generate an implementation. Details are in autotuner.md.
Feel free to tell us what else you need. We might support it soon. :)
The requirements are the same as for FlashAttention-2.
To install:

```bash
pip install flex-head-fa --no-build-isolation
```
Alternatively, you can compile from source:

```bash
python setup.py install
```
The usage remains the same as FlashAttention-2. You only need to replace `flash_attn` with `flex_head_fa`, as shown below:

```python
from flex_head_fa import flash_attn_func, flash_attn_with_kvcache
```
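For example, here is a minimal sketch of a call with mismatched QK/V head dimensions and unequal KV head counts. The tensor layout `(batch, seqlen, nheads, headdim)` and the `flash_attn_func` signature are assumed to match FlashAttention-2; the specific shapes below are illustrative only:

```python
import torch
from flex_head_fa import flash_attn_func

# Illustrative sketch (not from the repo): QKHeadDim=64, VHeadDim=128 with
# (num_heads_q, num_heads_k, num_heads_v) = (32, 4, 16), one of the configurations listed above.
# Layout follows the FlashAttention-2 convention: (batch, seqlen, nheads, headdim).
batch, seqlen = 2, 1024
q = torch.randn(batch, seqlen, 32, 64, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, 4, 64, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, 16, 128, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)
# The output keeps the query head count and the V head dimension: (batch, seqlen, 32, 128).
print(out.shape)
```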
We are also developing FlexHeadFA based on the latest FlashAttention-3. Currently, besides all configurations in FlashAttention-3, we also support:
- FlashAttention-3 with QKHeadDim=32, VHeadDim=64
- FlashAttention-3 forward + FlashAttention-2 backward with QKHeadDim=128, VHeadDim=256 (FlashAttention-3 backward is under development)
Try it with:

```bash
cd hopper
python setup.py install
```
Usage:

```python
from flash_attn_interface import flash_attn_func  # FlashAttention-3 forward + backward
from flash_attn_interface import flash_attn_f3b2_func as flash_attn_func  # FlashAttention-3 forward + FlashAttention-2 backward
```
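For instance, a minimal sketch of the QKHeadDim=32, VHeadDim=64 configuration on Hopper, assuming the `hopper` interface keeps the same `(batch, seqlen, nheads, headdim)` layout and call pattern as above (the exact return value may include auxiliary outputs such as the softmax LSE, depending on the interface version):

```python
import torch
from flash_attn_interface import flash_attn_func  # FlashAttention-3 forward + backward

# Illustrative sketch: QKHeadDim=32, VHeadDim=64 on a Hopper GPU.
batch, seqlen, nheads = 2, 2048, 16
q = torch.randn(batch, seqlen, nheads, 32, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads, 32, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, nheads, 64, device="cuda", dtype=torch.bfloat16)

out = flash_attn_func(q, k, v, causal=True)
```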
We measure the speedup compared to padding the QK and V hidden dimensions, on A100.
We display the FlexHeadFA speedup with the following settings:
- (QK head dim, V head dim): (32, 64), (64, 128), (128, 256); QK hidden dimension 2048 (i.e. 64, 32, or 16 heads).
- Sequence length: 512, 1k, 2k, 4k, 8k, 16k.
- Batch size: 16k / seqlen.
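Concretely, the benchmark grid implied by these settings can be enumerated as in the sketch below (illustrative only; the actual benchmark script in this repo may organize it differently):

```python
# Hypothetical enumeration of the benchmark grid described above.
head_dim_pairs = [(32, 64), (64, 128), (128, 256)]  # (QK head dim, V head dim)
qk_hidden_dim = 2048
seqlens = [512, 1024, 2048, 4096, 8192, 16384]

for qk_dim, v_dim in head_dim_pairs:
    nheads = qk_hidden_dim // qk_dim      # 64, 32, or 16 heads
    for seqlen in seqlens:
        batch_size = 16384 // seqlen      # batch size = 16k / seqlen
        print(f"nheads={nheads}, qk_dim={qk_dim}, v_dim={v_dim}, "
              f"seqlen={seqlen}, batch={batch_size}")
```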
This new release of FlexHeadFA has been tested on several GPT-style models, mostly on A100 GPUs.
If you encounter bugs, please open a GitHub Issue!