
[Phi] Add support for sdpa #29108

Merged
merged 1 commit into huggingface:main on Feb 20, 2024

Conversation

@hackyon (Contributor) commented Feb 19, 2024

What does this PR do?

Adding support for SDPA to Phi (See #28005)
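For context, with this change users should be able to opt into SDPA at load time. A minimal sketch of what that looks like (the checkpoint, dtype, and prompt below are illustrative, not part of this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative only: select the SDPA attention implementation when loading Phi.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```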

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@fxmarty @ArthurZucker

@hackyon (Contributor, Author) commented Feb 19, 2024

Hey @gugarosa @ArthurZucker @younesbelkada 👋

I'm looking for more places to add support for SDPA and figured Phi-2 could be a good one.

Been reading up on the issues regarding attention overflow for Phi-2 (#28673, #28488), and I think SDPA would probably be affected by it as well (if it chooses the FA kernels). So, this PR is dependent on #28673.

I think we should at least issue a warning in SDPA attention if flash attention is available and dtype == float16 or the autocast dtype == float16 (I'm not sure whether SDPA will try to autocast to fp16 under the hood).
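A rough sketch of the kind of guard I have in mind (the helper name and placement are hypothetical, not code from this PR):

```python
import logging
import torch

logger = logging.getLogger(__name__)

def _maybe_warn_fp16_sdpa(dtype: torch.dtype) -> None:
    # Hypothetical helper: warn when SDPA could dispatch to the flash-attention
    # kernel in float16, where Phi-2 is known to overflow (see #28673).
    autocast_fp16 = torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16
    flash_available = torch.backends.cuda.flash_sdp_enabled()
    if flash_available and (dtype == torch.float16 or autocast_fp16):
        logger.warning(
            "Phi with SDPA in float16 may overflow if the flash-attention kernel "
            "is selected; consider bfloat16 or float32 instead."
        )
```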

@younesbelkada (Contributor) left a comment

Looks very clean already!
I think we can move forward by merging the PR and iteratively see if users run into issues with fp16 / autocast. There is also a chance things get fixed when one uses SDPA. Since all the SDPA tests pass in CI with this PR, IMO we can merge!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@hackyon marked this pull request as ready for review on February 20, 2024 at 03:14
@hackyon (Contributor, Author) commented Feb 20, 2024

Thanks for the review!

Sounds good, marking this as ready. Let me know if/when there's any follow-up work and I'd be happy to take a stab at it.

@ArthurZucker (Collaborator) left a comment

LGTM, let's make sure the logits tests and the integration tests pass 🤗

@hackyon (Contributor, Author) commented Feb 20, 2024

Yup, all the common model and integration tests pass:

PASSED tests/models/phi/test_modeling_phi.py::PhiModelTest::test_eager_matches_sdpa_generate
PASSED tests/models/phi/test_modeling_phi.py::PhiModelTest::test_eager_matches_sdpa_inference_0_float16
PASSED tests/models/phi/test_modeling_phi.py::PhiModelTest::test_eager_matches_sdpa_inference_1_bfloat16
PASSED tests/models/phi/test_modeling_phi.py::PhiModelTest::test_eager_matches_sdpa_inference_2_float32
...
PASSED tests/models/phi/test_modeling_phi.py::PhiIntegrationTest::test_model_phi_1_5_logits
PASSED tests/models/phi/test_modeling_phi.py::PhiIntegrationTest::test_model_phi_1_logits
PASSED tests/models/phi/test_modeling_phi.py::PhiIntegrationTest::test_model_phi_2_logits
PASSED tests/models/phi/test_modeling_phi.py::PhiIntegrationTest::test_phi_2_generation

@ArthurZucker (Collaborator)

Just ran RUN_SLOW=1 pytest tests/models/phi:

_______________________________________________ PhiModelTest.test_eager_matches_sdpa_inference_1_bfloat16 ________________________________________________

a = (<tests.models.phi.test_modeling_phi.PhiModelTest testMethod=test_eager_matches_sdpa_inference_1_bfloat16>,), kw = {}

    @wraps(func)
    def standalone_func(*a, **kw):
>       return func(*(a + p.args), **p.kwargs, **kw)

../miniconda3/envs/py39/lib/python3.9/site-packages/parameterized/parameterized.py:620: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/test_modeling_common.py:3698: in test_eager_matches_sdpa_inference
    self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
E   AssertionError: False is not true : padding_side=left, use_mask=False, batch_size=5, enable_kernels=True: mean relative difference: 1.154e-02, torch atol = 0.01, torch rtol = 0.03
E   padding_side=left, use_mask=True, batch_size=5, enable_kernels=True: mean relative difference: 1.129e-02, torch atol = 0.01, torch rtol = 0.03
E   padding_side=right, use_mask=False, batch_size=5, enable_kernels=True: mean relative difference: 1.154e-02, torch atol = 0.01, torch rtol = 0.03
E   padding_side=right, use_mask=True, batch_size=5, enable_kernels=True: mean relative difference: 1.129e-02, torch atol = 0.01, torch rtol = 0.03

I think it is acceptable
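For reference, the failing assertion is a tolerance check between eager and SDPA logits; a simplified sketch of the comparison (not the exact test code) is:

```python
import torch

def mean_relative_difference(eager_logits: torch.Tensor, sdpa_logits: torch.Tensor) -> float:
    # Simplified sketch: mean |eager - sdpa| / |eager|, reported when
    # torch.allclose(eager, sdpa, atol=0.01, rtol=0.03) fails.
    eps = 1e-12
    return ((eager_logits - sdpa_logits).abs() / (eager_logits.abs() + eps)).mean().item()

# A mean relative difference of ~1.15e-02 against atol=0.01 / rtol=0.03 is just
# outside tolerance, which is consistent with a borderline/flaky failure.
```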

@ArthurZucker merged commit b8b1647 into huggingface:main on Feb 20, 2024
18 checks passed
@ArthurZucker (Collaborator)

Thanks @hackyon 😉

@hackyon (Contributor, Author) commented Feb 20, 2024

Thanks!

I just ran the PhiModelTest again on my server and it still passes, so it seems like a config issue :/ I'm running these tests/benchmarks on a Paperspace ml-in-a-box instance with A100 GPUs.

Let me know if you have recommendations on any better setup/config to use.

@hackyon deleted the sdpa-phi branch on February 20, 2024 at 14:05
@ArthurZucker (Collaborator)

No worries, it might be flaky as well. 1e-2 is alright I think; I have torch nightly as well.
