
Fixes default value of softmax_scale in PhiFlashAttention2. #28537

Merged: 2 commits merged into huggingface:main on Jan 17, 2024

Conversation

gugarosa (Contributor) commented on Jan 16, 2024:

What does this PR do?

  • Phi has never used softmax_scale=1.0 with Flash-Attention, so the default is being moved to None. This tentatively fixes the issues reported when fine-tuning Phi-based checkpoints with Flash-Attention 2 turned on (see the sketch below this description).

  • Documentation is also updated to reflect the official Phi checkpoints.

Fixes #28488 (tentative)
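For context, a minimal sketch (not part of the PR) of loading a Phi checkpoint with Flash-Attention 2 enabled, which is the setup this fix targets. It assumes a recent transformers release that supports the attn_implementation argument and an installed flash-attn package; the checkpoint id is illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "phi-2"  # illustrative id; use the official Phi checkpoint from the Hub

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16,                # flash-attention requires fp16/bf16
    attn_implementation="flash_attention_2",  # routes attention through PhiFlashAttention2
)

With this PR, the flash-attention path no longer overrides the softmax scale, so fine-tuning through it should behave like the non-flash path.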

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @susnato

Comment on lines -87 to +88
- >>> model = AutoModelForCausalLM.from_pretrained("susnato/phi-2")
- >>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2")
+ >>> model = AutoModelForCausalLM.from_pretrained("phi-2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("phi-2")
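For reference, a short doctest-style sketch of how the updated example would run end-to-end. The checkpoint id is copied from the diff above; the official Hub id may include an organization prefix, and the prompt is illustrative.

>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("phi-2")      # id as in the doc diff above
>>> tokenizer = AutoTokenizer.from_pretrained("phi-2")
>>> inputs = tokenizer("def print_prime(n):", return_tensors="pt")
>>> outputs = model.generate(**inputs, max_new_tokens=30)
>>> print(tokenizer.batch_decode(outputs)[0])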

Contributor:
Thanks for changing these checkpoints. 🙌
I was about to open a PR to change these

gugarosa (Contributor, author):

No problems!

Comment on lines 508 to 510
  attn_output = self._flash_attention_forward(
-     query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0
+     query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
  )

Contributor:

Ah, was this the reason for the issue regarding fine-tuning?
Now I am curious how the FA tests were passing before...

Anyway thanks a lot for fixing this!

gugarosa (Contributor, author):

I hope it is fixed; at least, I am now able to see the same fine-tuning loss with and without flash-attention.

We pre-trained the Phi models using 1 / sqrt(head_dim) as the softmax scale, and flash-attention uses that very same value when softmax_scale=None.
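To illustrate the point numerically, a rough pure-PyTorch reference of the same math (not the fused flash-attention kernel): with softmax_scale=None the scale falls back to 1/sqrt(head_dim), matching Phi's pre-training, whereas 1.0 skips the scaling entirely and yields different attention outputs.

import math
import torch

def reference_attention(q, k, v, softmax_scale=None):
    # Mirror the documented flash-attention default: None means 1/sqrt(head_dim).
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    return torch.matmul(torch.softmax(scores, dim=-1), v)

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim); shapes are illustrative
k, v = torch.randn_like(q), torch.randn_like(q)

out_scaled = reference_attention(q, k, v, softmax_scale=None)   # new default: 1/sqrt(64)
out_unscaled = reference_attention(q, k, v, softmax_scale=1.0)  # old hard-coded value
print(torch.allclose(out_scaled, out_unscaled))  # False: the mismatch seen at fine-tuning time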

Contributor:

BTW, just to be sure, could you please run all the flash attention tests (for Phi) to check whether they pass?

RUN_SLOW=1 pytest -m flash_attn_test tests/models/phi --verbose

It should not be necessary since this already fixes the fine-tuning issue, but just to be sure.

gugarosa (Contributor, author):

Just ran and everything has passed!

Collaborator:

Yep, our CIs don't test flash attention, bit of a pity!

gugarosa marked this pull request as ready for review on January 16, 2024, 18:37.

gugarosa (Contributor, author) commented on Jan 16, 2024:

The loss=0.0 error while fine-tuning with FP16 is another issue and I do have an ugly fix, but will look into it with more patience (and use a separate PR).

ArthurZucker (Collaborator) left a comment:

Thanks a lot for this fix! Very tricky indeed


ArthurZucker merged commit d93ef7d into huggingface:main on Jan 17, 2024. 19 checks passed.
gugarosa (Contributor, author):
No problems! Thanks for the merge!

gugarosa deleted the fix-phi-tune branch on January 17, 2024, 13:25.
younesbelkada (Contributor):
Thanks very much @gugarosa for the deep dive and the fix!

HuggingFaceDocBuilderDev:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

wgifford pushed a commit to wgifford/transformers that referenced this pull request on Jan 21, 2024:
…ingface#28537)

* fix(phi): Phi does not use softmax_scale in Flash-Attention.

* chore(docs): Update Phi docs.

AjayP13 pushed a commit to AjayP13/transformers that referenced this pull request on Jan 22, 2024:
…ingface#28537)

* fix(phi): Phi does not use softmax_scale in Flash-Attention.

* chore(docs): Update Phi docs.
Successfully merging this pull request may close these issues.

fine tuning the updated Phi-2 with flash-attn-2 produces very high loss > 2