Fixes default value of softmax_scale in PhiFlashAttention2 #28537
Conversation
- >>> model = AutoModelForCausalLM.from_pretrained("susnato/phi-2")
- >>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2")
+ >>> model = AutoModelForCausalLM.from_pretrained("phi-2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("phi-2")
Thanks for changing these checkpoints. 🙌 I was about to open a PR to change these.
No problem!
attn_output = self._flash_attention_forward(
-     query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0
+     query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
)
Ah, was this the reason for the issue regarding fine-tuning?
Now I am curious how the FA tests were passing before...
Anyway, thanks a lot for fixing this!
I hope it is fixed; at least, I am now able to see the same fine-tuning loss with and without flash-attention.
We pre-trained the Phi models using 1 / sqrt(head_dim) as the softmax scale, and Flash-Attention uses that very same value when softmax_scale=None.
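For context, here is a minimal sketch of the scaling convention (the reference_attention helper below is hypothetical, not the transformers or flash-attn code): when softmax_scale is None, Flash-Attention falls back to 1 / sqrt(head_dim), i.e. standard scaled dot-product attention, which is the scale Phi was pre-trained with.

```python
# Minimal sketch (hypothetical helper, not the actual flash-attn kernel): illustrates
# that softmax_scale=None means "use 1 / sqrt(head_dim)", i.e. standard scaled
# dot-product attention, which matches the scale Phi was pre-trained with.
import math
import torch

def reference_attention(q, k, v, softmax_scale=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])  # Flash-Attention's documented default
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    return torch.matmul(torch.softmax(scores, dim=-1), v)
```

Passing softmax_scale=1.0 instead silently skips this scaling, which is why the fine-tuning behavior diverged from the eager attention path.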
BTW, just to be sure, could you please run all the Flash-Attention tests for Phi and check that they pass?
RUN_SLOW=1 pytest -m flash_attn_test tests/models/phi --verbose
It should not be necessary since the change already fixes the fine-tuning issue, but just to be sure.
Just ran them and everything passed!
Yep, our CIs don't test flash attention, bit of a pity!
Thanks a lot for this fix! Very tricky indeed.
No problem! Thanks for the merge!
Thanks very much @gugarosa for the deep dive and the fix!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…ingface#28537) * fix(phi): Phi does not use softmax_scale in Flash-Attention. * chore(docs): Update Phi docs.
What does this PR do?
Phi has never used softmax_scale=1.0 with Flash-Attention, so the default is being moved to None. This tentatively fixes any issues regarding fine-tuning Phi-based checkpoints when Flash-Attention 2 is turned on. Documentation is also updated to reflect the official Phi checkpoints.
Fixes #28488 (tentative)
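As a quick sanity check of the fix, a minimal sketch like the one below could compare the Flash-Attention 2 path against eager attention. It assumes the official microsoft/phi-2 checkpoint name, a CUDA device, and flash-attn installed; it is not part of this PR.

```python
# Sketch only: loads Phi-2 twice, once with Flash-Attention 2 and once with eager
# attention, and checks that the logits now agree up to numerical tolerance.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "microsoft/phi-2"  # assumed official checkpoint name
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("def hello_world():", return_tensors="pt").to("cuda")

model_fa2 = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to("cuda")
model_eager = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, attn_implementation="eager"
).to("cuda")

with torch.no_grad():
    logits_fa2 = model_fa2(**inputs).logits
    logits_eager = model_eager(**inputs).logits

# With softmax_scale=None the two paths should be close; with 1.0 they diverged.
print(torch.max(torch.abs(logits_fa2 - logits_eager)))
```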
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@ArthurZucker @susnato