
Better SDPA unmasking implementation #29318

Merged: 9 commits into huggingface:main on Feb 28, 2024

Conversation

@fxmarty (Contributor) commented Feb 27, 2024

Since @ArthurZucker improved the unmasking for SDPA's mem-efficient code path, let's do the same for all architectures using SDPA. See #27931.
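
(For context, a minimal sketch of the kind of unmasking being propagated here; the function name is illustrative and this is not the PR's actual code. Rows of the expanded float attention mask that are fully masked, e.g. rows for left-padded tokens, are reset to zero, i.e. "unmasked", so that SDPA's mem-efficient backend does not produce NaN for them.)

import torch

def unmask_fully_masked_rows(expanded_mask: torch.Tensor, min_dtype: float) -> torch.Tensor:
    # expanded_mask: float mask of shape [batch, 1, query_len, kv_len] with 0 for
    # "attend" and min_dtype (the dtype's most negative value) for "masked".
    # A row that is masked everywhere would make the mem-efficient SDPA kernel
    # return NaN; multiplying such rows by 0 turns them into fully unmasked rows.
    fully_masked = torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
    return expanded_mask.mul(~fully_masked)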

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@fxmarty (Contributor, Author) commented Feb 27, 2024

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=3 pytest tests/ -k "test_eager_matches_sdpa_inference" -s -vvvvv passes, except for qwen2 (which is unrelated; see #28436 (comment)).

Comment on lines 204 to 205
if expanded_mask.dtype == torch.bool:
    raise ValueError("AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor.")
fxmarty (Contributor, Author) commented:

Some models (GPT BigCode) use bool tensors, but Arthur's implementation can't work for that dtype.

ArthurZucker (Collaborator) replied:

can't we cast it and replace the min with 0?

fxmarty (Contributor, Author) replied:

For now I expect the cast to be done explicitly in the modeling file.
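
(Aside, for illustration only and not the PR's actual modeling code: such an explicit cast from a bool mask, where True means "attend", to the float form _unmask_unattended expects could look roughly like the sketch below; the helper name is hypothetical.)

import torch

def bool_mask_to_float(bool_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical helper: True = attend, False = masked. Masked positions get the
    # dtype's most negative value, matching what the float-based unmasking expects.
    min_dtype = torch.finfo(dtype).min
    float_mask = torch.zeros_like(bool_mask, dtype=dtype)
    return float_mask.masked_fill(~bool_mask, min_dtype)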

ArthurZucker (Collaborator) replied:

fine by me

@ArthurZucker (Collaborator) left a review:

LGTM, thanks for propagating the changes.

Two further review threads on src/transformers/modeling_attn_mask_utils.py were marked as resolved.
@fxmarty merged commit 49204c1 into huggingface:main on Feb 28, 2024. 20 checks passed.
itazap pushed a commit that referenced this pull request on May 14, 2024:
* better unmask imple

* comment

* typo

* bug report pytorch

* cleanup

* fix import

* add back example

* retrigger ci

* come on