[Phi] Add support for sdpa #29108
Conversation
Hey @gugarosa @ArthurZucker @younesbelkada 👋 I'm looking for more places to add support for SDPA and figured Phi-2 could be a good one. I've been reading up on the issues regarding attention overflow for Phi-2 (#28673, #28488), and I think SDPA would probably be affected by it as well (if it chooses the FA kernels). So, this issue is dependent on #28673. I think we should at least issue a warning in SDPA attention if flash attention is available and dtype == float16 or autocast dtype == float16 (I'm not sure if SDPA will try to autocast to fp16 under the hood).
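For illustration, here is a minimal sketch of the kind of warning described above. The helper name and where it would be called from are assumptions for the example, not the actual transformers implementation.

```python
# Sketch only: a hypothetical helper inside the attention module, not the
# actual transformers code.
import logging

import torch

logger = logging.getLogger(__name__)


def _warn_if_fp16_flash_sdpa(query: torch.Tensor) -> None:
    # Phi-2 is known to overflow in fp16 attention (#28673, #28488). If SDPA may
    # dispatch to the flash attention kernel and the inputs are (or will be
    # autocast to) float16, let the user know results could be affected.
    autocast_fp16 = torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16
    if torch.backends.cuda.flash_sdp_enabled() and (query.dtype == torch.float16 or autocast_fp16):
        logger.warning(
            "SDPA may select the flash attention kernel in float16; Phi-2 can overflow "
            "in fp16 attention, so consider bfloat16 or float32 instead."
        )
```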
Looks very clean already!
I think we can move forward by merging the PR and iteratively see if users run into issues with fp16 / autocast. There is also a chance things get fixed if one uses SDPA. Since all SDPA tests pass in the CI with this PR, IMO we can merge!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the review! Sounds good, marking this as ready. Let me know if/when there's any follow-up work and I'd be happy to take a stab at it.
LGTM, let's make sure the logits and the integration tests pass 🤗
Yup, all the common model and integration tests pass:
Just ran:

_______________________________________________ PhiModelTest.test_eager_matches_sdpa_inference_1_bfloat16 ________________________________________________
a = (<tests.models.phi.test_modeling_phi.PhiModelTest testMethod=test_eager_matches_sdpa_inference_1_bfloat16>,), kw = {}
@wraps(func)
def standalone_func(*a, **kw):
> return func(*(a + p.args), **p.kwargs, **kw)
../miniconda3/envs/py39/lib/python3.9/site-packages/parameterized/parameterized.py:620:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/test_modeling_common.py:3698: in test_eager_matches_sdpa_inference
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
E AssertionError: False is not true : padding_side=left, use_mask=False, batch_size=5, enable_kernels=True: mean relative difference: 1.154e-02, torch atol = 0.01, torch rtol = 0.03
E padding_side=left, use_mask=True, batch_size=5, enable_kernels=True: mean relative difference: 1.129e-02, torch atol = 0.01, torch rtol = 0.03
E padding_side=right, use_mask=False, batch_size=5, enable_kernels=True: mean relative difference: 1.154e-02, torch atol = 0.01, torch rtol = 0.03
E padding_side=right, use_mask=True, batch_size=5, enable_kernels=True: mean relative difference: 1.129e-02, torch atol = 0.01, torch rtol = 0.03

I think it is acceptable.
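As an aside, here is a rough sketch of the kind of eager-vs-SDPA comparison behind the failure message above. This is not the actual `test_eager_matches_sdpa_inference` code; the function name and tolerances are taken from the log for illustration.

```python
# Illustrative only: compare eager and SDPA outputs with torch.allclose and
# report a mean relative difference, as in the failure message above.
from typing import Optional

import torch


def describe_mismatch(eager_out: torch.Tensor, sdpa_out: torch.Tensor,
                      atol: float = 1e-2, rtol: float = 3e-2) -> Optional[str]:
    if torch.allclose(eager_out, sdpa_out, atol=atol, rtol=rtol):
        return None  # within tolerance, no failure case to report
    mean_rel_diff = ((eager_out - sdpa_out).abs() / (eager_out.abs() + 1e-12)).mean().item()
    return f"mean relative difference: {mean_rel_diff:.3e}, torch atol = {atol}, torch rtol = {rtol}"
```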
Thanks @hackyon 😉
Thanks! I just ran the PhiModelTest again on my server and it still passes, so it seems like a config issue :/ I'm running these tests/benchmarks on a Paperspace ml-in-a-box instance with A100 GPUs. Let me know if you have recommendations on a better setup/config to use.
No worries, it might be flaky as well. 1e-2 is alright I think. I have torch nightly as well.
What does this PR do?
Adds SDPA support to Phi (see #28005).
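As a rough illustration of what SDPA support means here (not the code added in this PR; the function name and signature are assumptions), the model's attention can route through torch.nn.functional.scaled_dot_product_attention, which picks an efficient kernel when one is available:

```python
# Sketch only: names and signature are illustrative, not the PR's actual code.
import torch
import torch.nn.functional as F


def sdpa_attention_forward(query, key, value, attention_mask=None, dropout_p=0.0):
    # scaled_dot_product_attention dispatches to flash / memory-efficient / math
    # kernels depending on hardware, dtype, and mask support.
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout_p,
        is_causal=attention_mask is None,  # fall back to causal masking when no mask is passed
    )
```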
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@fxmarty @ArthurZucker