Overview
In this PR we implement FlashAttention forward + backward kernels for FP32 training.
All results were measured on a V100 with B = 4, T = 1024, C = 768, NH = 12.
Requirements
The kernels require >= 64 KB of shared memory, so they should work on all GPUs with compute capability >= 7.0 (sm_70).
FP32 end-to-end training
Training was done with B = 4, T = 1024, C = 768, NH = 12.
We use:
For reasons we have not yet pinned down, training with the FlashAttention kernels results in a slightly higher loss.
Long context benchmark
We also tested long-context performance by fixing B = 4, C = 768, NH = 12 and scaling T.
Improvements
We can improve the kernels further by permuting the shared memory layout to reduce bank conflicts.