-
Notifications
You must be signed in to change notification settings - Fork 27.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes default value of softmax_scale
in PhiFlashAttention2
.
#28537
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -506,7 +506,7 @@ def forward( | |
value_states = value_states.to(target_dtype) | ||
|
||
attn_output = self._flash_attention_forward( | ||
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0 | ||
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None | ||
) | ||
Comment on lines
508
to
510
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, was this the reason for the issue regarding fine-tuning? Anyway thanks a lot for fixing this! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I hope it is fixed, at least, I am now able to see the same fine-tuning loss with/without flash-attention. We pre-trained the Phi models using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BTW just to be sure, could you please run all flash attention tests(for
Should not be necessary since it's already fixing the fine-tuning issue, but just to be sure. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just ran and everything has passed! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, our CIs don't test flash attention, bit of a pity! |
||
|
||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for changing these checkpoints. 🙌
I was about to open a PR to change these
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No problems!