FP32 FlashAttention #781

Open · wants to merge 1 commit into master
Conversation

@ssiu commented on Oct 20, 2024

Overview

In this PR we implement FlashAttention forward + backward kernels for FP32 training.
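
For context, the heart of the forward kernel is the standard FlashAttention online-softmax recurrence, shown here as a minimal single-row FP32 sketch (the function and variable names are illustrative, not this PR's identifiers; tiling and parallelization are omitted):

```cuda
// Minimal FP32 sketch of the online-softmax recurrence for one query row.
// All names are illustrative; the real kernel tiles K/V through shared memory.
__device__ void flash_row_sketch(const float* q,   // (d,) one query row
                                 const float* K,   // (T, d) keys
                                 const float* V,   // (T, d) values
                                 float* o,         // (d,) output row
                                 float* L,         // logsumexp, saved for backward
                                 int T, int d, float scale) {
    float m = -INFINITY;  // running row max
    float l = 0.0f;       // running softmax denominator
    float acc[128];       // unnormalized sum of p * V rows; assumes d <= 128
    for (int i = 0; i < d; i++) acc[i] = 0.0f;

    for (int t = 0; t < T; t++) {
        float s = 0.0f;   // s = scale * (q . K[t])
        for (int i = 0; i < d; i++) s += q[i] * K[t * d + i];
        s *= scale;
        float m_new = fmaxf(m, s);
        float corr = expf(m - m_new);  // rescales the previous state
        float p = expf(s - m_new);
        l = l * corr + p;
        for (int i = 0; i < d; i++)
            acc[i] = acc[i] * corr + p * V[t * d + i];
        m = m_new;
    }
    for (int i = 0; i < d; i++) o[i] = acc[i] / l;
    *L = m + logf(l);  // L = m + log l, saved so backward can rebuild softmax
}
```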

All benchmarks were run on a V100.

For B = 4, T = 1024, C = 768, NH = 12:

| attention_forward4 (ms) | flash_attention_forward (ms) | speedup |
|---|---|---|
| 1.64 | 1.15 | 1.43x |

| attention_backward8 (ms) | flash_attention_backward (ms) | speedup |
|---|---|---|
| 3.06 | 2.86 | 1.07x |


Requirements

The kernels require 64 KB of shared memory, so they should work on all GPUs with compute capability >= 7.0 (SM 70).
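
One practical note: the default per-block limit is 48 KB, so if the 64 KB is allocated as dynamic shared memory, the launch has to opt in via cudaFuncSetAttribute. A minimal sketch, assuming a hypothetical kernel symbol flash_attention_forward_kernel:

```cuda
// Sketch: requesting more than 48 KB of dynamic shared memory on cc 7.0+.
// flash_attention_forward_kernel is a stand-in for the real kernel symbol.
#include <cuda_runtime.h>

__global__ void flash_attention_forward_kernel(float* out, const float* inp) {
    extern __shared__ float smem[];  // 64 KB of dynamic shared memory
    // ... kernel body elided in this sketch ...
}

void launch_with_64kb_smem(float* out, const float* inp, dim3 grid, dim3 block) {
    const int smem_bytes = 64 * 1024;
    // Required once before launch when requesting more than the 48 KB default
    cudaFuncSetAttribute(flash_attention_forward_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         smem_bytes);
    flash_attention_forward_kernel<<<grid, block, smem_bytes>>>(out, inp);
}
```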

FP32 end-to-end training

Training was done with B = 4, T = 1024, C = 768, NH = 12.

We reuse the following activation buffers (see the sketch after this list):

  • qkvr for inp (since inp is needed again in the backward pass)
  • l_att for L = m + log l (the softmax logsumexp)
  • l_fch for D = rowsum(dO * O)
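
For the last item, a minimal sketch of the preprocessing kernel that produces D (one thread per row for clarity; names are illustrative, not the PR's identifiers):

```cuda
// Sketch of the backward preprocessing pass that fills the D buffer with
// rowsum(dO * O). One thread per row for clarity; a real kernel would also
// parallelize the inner loop over the head dimension.
__global__ void compute_D_sketch(const float* dO,  // (rows, d) output grads
                                 const float* O,   // (rows, d) forward outputs
                                 float* D,         // (rows,) result
                                 int rows, int d) {
    int r = blockIdx.x * blockDim.x + threadIdx.x;
    if (r >= rows) return;
    float sum = 0.0f;
    for (int i = 0; i < d; i++)
        sum += dO[r * d + i] * O[r * d + i];  // elementwise product, then row sum
    D[r] = sum;  // used in backward as dS_ij = P_ij * (dP_ij - D_i)
}
```
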
|  | attention_forward4 + attention_backward8 | flash_attention_forward + flash_attention_backward | speedup |
|---|---|---|---|
| Total average iteration time (ms) | 337.81 | 330.88 | 1.02x |
| Final loss | 3.49 | 3.51 |  |

Training with the flash attention kernels results in a slightly higher final loss (3.51 vs. 3.49). The cause is not clear; one plausible factor is that the online softmax changes the floating-point summation order, so results are not bitwise identical to attention_forward4.

Long context benchmark

We also tested long-context performance by fixing B = 4, C = 768, NH = 12 and varying T.

| T | attention_forward4 (ms) | flash_attention_forward (ms) | speedup |
|---|---|---|---|
| 1024 | 1.64 | 1.15 | 1.43x |
| 2048 | 6.11 | 3.78 | 1.62x |
| 4096 | 24.62 | 13.53 | 1.82x |

| T | attention_backward8 (ms) | flash_attention_backward (ms) | speedup |
|---|---|---|---|
| 1024 | 3.06 | 2.86 | 1.07x |
| 2048 | 11.08 | 10.14 | 1.09x |
| 4096 | 44.69 | 38.54 | 1.16x |


Improvements

The kernels can be improved further by permuting the shared memory layout to reduce bank conflicts; the sketch below shows the general idea.
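
A common way to do this is an XOR swizzle of the shared-memory index, so that both row-wise and column-wise accesses within a warp hit 32 distinct banks. A sketch of the idea on a plain 32x32 transpose tile (sizes and names are illustrative, not this PR's actual layout):

```cuda
// Sketch: XOR-swizzled shared memory indexing to avoid bank conflicts.
// Illustrated on a 32x32 transpose tile; launch with block = dim3(32, 32).
#define TILE 32

__device__ __forceinline__ int swizzle(int row, int col) {
    // XOR the column with the row: a logical column of the tile then maps
    // to 32 different banks instead of a single one.
    return row * TILE + (col ^ row);
}

__global__ void transpose_swizzled(const float* in, float* out, int n) {
    __shared__ float tile[TILE * TILE];
    int x = blockIdx.x * TILE + threadIdx.x;
    int y = blockIdx.y * TILE + threadIdx.y;
    if (x < n && y < n)
        tile[swizzle(threadIdx.y, threadIdx.x)] = in[y * n + x];  // conflict-free store
    __syncthreads();
    int tx = blockIdx.y * TILE + threadIdx.x;
    int ty = blockIdx.x * TILE + threadIdx.y;
    if (tx < n && ty < n)
        out[ty * n + tx] = tile[swizzle(threadIdx.x, threadIdx.y)];  // conflict-free load
}
```

Without the swizzle, the column-wise read-back would be a 32-way bank conflict; unlike padding the tile, the XOR keeps the layout a bijection within each row, so no extra shared memory is needed.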
