Slowdown in Training Speed Due to SDPA Mask Fix in Version 4.40.0 #30461
Comments
cc @fxmarty
Hi @achew010, thank you for the report. Two PRs may be at play here, #30127 and #30070. Long story short, whether an explicit attention mask ends up being passed to SDPA determines which backend PyTorch can dispatch to.
As you can see here, a check is done on the sliding window when the attention mask for SDPA is prepared.
This was already the case for transformers<=4.39 (and also in the Mistral public code release). Unfortunately, apart from the original HazyResearch/flash-attn implementation, the other attention backends do not handle the sliding window natively, so it has to be encoded in an explicit mask. It is still unclear to me why you see this regression, given that transformers==4.39 used to always pass an explicit mask to SDPA here as well.
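Since the linked check did not survive the copy above, here is a minimal, hypothetical sketch of the decision being described. The function name and the exact boundary condition are my own, not transformers internals: when neither padding nor a sliding window has to be encoded, the mask can be dropped and `is_causal=True` passed, which leaves SDPA free to dispatch to the flash backend; once an explicit 4D mask is materialized, flash is no longer eligible.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def needs_explicit_mask(seq_len: int, sliding_window: Optional[int], has_padding: bool) -> bool:
    # Hypothetical stand-in for the check described above (not transformers code):
    # an explicit mask is only required when padding or a sliding window must be
    # encoded; a plain causal mask can be expressed via is_causal=True instead.
    if has_padding:
        return True
    if sliding_window is not None and seq_len > sliding_window:
        return True
    return False


device, dtype = ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
q = torch.randn(1, 32, 4096, 128, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)

if not needs_explicit_mask(seq_len=4096, sliding_window=4096, has_padding=False):
    # No mask materialized: SDPA is free to pick the flash backend.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
```

Whether the boundary should be strict or inclusive with respect to the window size is exactly the kind of detail discussed in this thread; the condition above is only illustrative.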
Thanks for giving some context @fxmarty; from your explanation I have a better understanding of what caused the regression.
Indeed, as you said, by playing around with this context manager and forcing both setups to use the same backend, I could confirm that the backend selection accounts for the difference.
The introduction of a sliding window here influences which backend the SDPA kernel will use. In my setup, having a sliding window specified means that a different (slower) backend is selected. This clarifies everything, thanks a lot for the help!
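For completeness, the backend-pinning experiment mentioned above can be done with PyTorch's SDPA kernel context manager. This is a rough sketch assuming a CUDA build of PyTorch 2.x; which backend was actually pinned is not visible in the truncated comment, so the memory-efficient choice below is only an example.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict SDPA to a single backend so that both transformers 4.39 and 4.40
# run through the same kernel regardless of how the attention mask is built.
with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                    enable_math=False,
                                    enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```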
System Info
Hi,
I have been doing some PEFT tuning with Mistral/Mixtral and recently I observed a slowdown in training since the release of version 4.40.0. I narrowed it down to this fix in 40eb6d6, where the sliding window is now specified in _prepare_4d_causal_attention_mask_for_sdpa. I ran a simple training job and the training statistics produced two different sets of throughputs.
When my training sequence length is within or equal to the sliding window (i.e. seqlen = 4096, window = 4096), it should fall back to letting the SDPA kernel handle the causal mask itself. I also don't see the computation savings at sequence length 8192 from the introduction of sliding window attention, compared to having no windowed causal mask at all (calculating attention across all 8192 tokens).
Below is a dummy example showing that simply not passing the causal mask into PyTorch's SDPA function (allowing the kernel to handle the causal mask itself), versus specifying the sliding-window mask, has a significant impact on the processing speed of the kernel.
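The dummy example itself did not survive the copy, so the following is a reconstruction under my own assumptions (the shapes, window size, and crude timing loop are mine, not the author's script): it compares SDPA with no mask and `is_causal=True` against SDPA with an explicit sliding-window causal mask of the kind transformers now materializes.

```python
import time

import torch
import torch.nn.functional as F

device, dtype = ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)
bsz, heads, head_dim, window = 1, 32, 128, 4096
seq_len = 8192 if device == "cuda" else 1024  # keep the CPU fallback cheap

q = torch.randn(bsz, heads, seq_len, head_dim, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)

# Additive sliding-window causal mask: query i attends to key j only if
# i - window < j <= i; everything else is set to -inf.
idx = torch.arange(seq_len, device=device)
keep = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < window)
window_mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
window_mask.masked_fill_(~keep, float("-inf"))


def timed(fn, iters=10):
    # Crude wall-clock timing; synchronize so GPU work is actually counted.
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


no_mask = timed(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True))
explicit = timed(lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=window_mask))
print(f"is_causal=True, no mask:      {no_mask * 1e3:.1f} ms")
print(f"explicit sliding-window mask: {explicit * 1e3:.1f} ms")
```

With an explicit mask, SDPA typically cannot use the flash backend and falls back to a slower one, which is the gap being reported here.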
Is this slowdown something we should expect from using the SDPA module with the current fix?
I attached a simple script to reproduce the issue
System Info
Who can help?
No response
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Expected behavior
Throughput should remain the same for sequence lengths within the window size for SDPA.
Throughput should be slightly faster (from fewer computations in local attention) than regular attention (when no sliding window is specified in the causal mask) for longer sequence lengths.