Closed
Description
For BS=1, the gen-phase flash_attn_vec_ext_f32 kernel is launched with a constant parallel_blocks value of 4 (check the code). However, parallel_blocks = 4 causes poor occupancy on the GPU.
Consider the following models. With parallel_blocks (PB) fixed at 4, the current occupancy is far below what is achievable if the parallel_blocks value is increased.
Model | num_heads | head_dim | occupancy with PB=4 on RTX 4090 | achievable occupancy with optimal PB value on RTX 4090 |
---|---|---|---|---|
Llama 3B | 24 | 128 | 0.06 | 0.25 |
Llama 8B | 32 | 128 | 0.08 | 0.25 |
Qwen 1.5B | 12 | 128 | 0.03 | 0.25 |
Qwen 7B | 28 | 128 | 0.07 | 0.25 |
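
The PB=4 column above can be roughly reproduced with a back-of-the-envelope model. This is only a sketch under assumptions that are not stated in this issue: the RTX 4090 has 128 SMs with at most 48 resident warps each, the kernel runs 4 warps (128 threads) per block, and one block is launched per (head, parallel block) pair at BS=1. The helper names `occupancy` and `min_parallel_blocks` are hypothetical and are not part of the kernel code.

```python
import math

# Assumed hardware/launch parameters (not taken from the issue text).
NUM_SMS = 128            # assumed SM count for RTX 4090
MAX_WARPS_PER_SM = 48    # assumed resident-warp limit per SM
WARPS_PER_BLOCK = 4      # assumed 128 threads per block for head_dim = 128


def occupancy(num_heads: int, parallel_blocks: int, batch_size: int = 1) -> float:
    """Whole-GPU warp occupancy if every launched block is resident at once.

    Ignores register and shared-memory limits, so it is only an estimate.
    """
    blocks = num_heads * parallel_blocks * batch_size
    capacity = NUM_SMS * MAX_WARPS_PER_SM
    return min(1.0, blocks * WARPS_PER_BLOCK / capacity)


def min_parallel_blocks(num_heads: int, target: float = 0.25) -> int:
    """Smallest parallel_blocks whose estimated occupancy reaches `target`.

    The default target of 0.25 is the achievable occupancy listed in the table.
    """
    warps_needed = target * NUM_SMS * MAX_WARPS_PER_SM
    return math.ceil(warps_needed / (num_heads * WARPS_PER_BLOCK))


if __name__ == "__main__":
    models = {"Llama 3B": 24, "Llama 8B": 32, "Qwen 1.5B": 12, "Qwen 7B": 28}
    for name, heads in models.items():
        pb = min_parallel_blocks(heads)
        print(f"{name}: PB=4 -> {occupancy(heads, 4):.2f}, "
              f"estimated PB for 0.25 -> {pb}")
```

Under these assumptions the estimate matches the PB=4 column (0.06, 0.08, 0.03, 0.07). The PB values it suggests for reaching 0.25 are illustrative only and are not necessarily what an actual fix would choose, since the real limit depends on the kernel's register and shared-memory usage.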
I have a change that addresses this issue; it improves gen-phase performance by up to 14%.