Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mistral&Mixtral]Add sliding window param to sdpa after torch 2.2.0 #29220

Closed
wants to merge 55 commits into from

Conversation

ehuaa
Copy link
Contributor

@ehuaa ehuaa commented Feb 22, 2024

What does this PR do?

Add sliding window param as described in #28980, and the solution is check if torch version greater than or equal as torch 2.2.0 (release version) , and add sliding window to attention_mask if it's true.
Tests of sdpa and sliding window attention with torch 2.2.0 has been passed with
RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/models/mistral/test_modeling_mistral.py
in my local environment.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ √] Did you read the contributor guideline,
    Pull Request section?
  • [ √] Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • [ √] Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @fxmarty

@ehuaa
Copy link
Contributor Author

ehuaa commented Feb 22, 2024

I modified mixtral model as well, because of they have same sliding window structure here.
It's weird that when i modified qwen2 model with the same code ,the tests went wrong, so i reverted the modification in modeling_qwen2 of passing sliding window param to sdpa attention mask and only fix the non-contigous issue.
Do you guys have any idea why sliding window mask is not worked with sdpa in qwen2, maybe it's too strict for the numerical issue?

qwen2 failed with these tests below:
=================================================================================================== short test summary info ===================================================================================================
FAILED tests/models/qwen2/test_modeling_qwen2.py::Qwen2ModelTest::test_cpu_offload - AssertionError: False is not true
FAILED tests/models/qwen2/test_modeling_qwen2.py::Qwen2ModelTest::test_disk_offload_bin - AssertionError: False is not true
FAILED tests/models/qwen2/test_modeling_qwen2.py::Qwen2ModelTest::test_disk_offload_safetensors - AssertionError: False is not true
FAILED tests/models/qwen2/test_modeling_qwen2.py::Qwen2ModelTest::test_feed_forward_chunking - AssertionError: False is not true
FAILED tests/models/qwen2/test_modeling_qwen2.py::Qwen2ModelTest::test_generate_continue_from_past_key_values - AssertionError: False is not true
FAILED tests/models/qwen2/test_modeling_qwen2.py::Qwen2ModelTest::test_left_padding_compatibility - AssertionError: False is not true
FAILED tests/models/qwen2/test_modeling_qwen2.py::Qwen2ModelTest::test_model_parallelism - AssertionError: False is not true
=================================================================================== 7 failed, 83 passed, 53 skipped, 114 warnings in 29.42s =======

@ehuaa ehuaa changed the title [Mistral&Mixtral&Qwen2]Add sliding window param to sdpa after torch==2.2.0 [Mistral&Mixtra]Add sliding window param to sdpa after torch==2.2.0 Feb 22, 2024
@ehuaa ehuaa changed the title [Mistral&Mixtra]Add sliding window param to sdpa after torch==2.2.0 [Mistral&Mixtral]Add sliding window param to sdpa after torch 2.2.0 Feb 22, 2024
@ehuaa
Copy link
Contributor Author

ehuaa commented Feb 22, 2024

Hey @fxmarty @ArthurZucker ,
I have added sliding window param as described in #28980, can you have a look at this PR? Thanks.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks alright but let's not add the unrelated change

src/transformers/models/mistral/modeling_mistral.py Outdated Show resolved Hide resolved
src/transformers/models/qwen2/modeling_qwen2.py Outdated Show resolved Hide resolved
@ehuaa ehuaa requested a review from ArthurZucker February 24, 2024 06:46
@ehuaa
Copy link
Contributor Author

ehuaa commented Feb 24, 2024

Looks alright but let's not add the unrelated change

I have deleted the unrelated changes and please review them again. The failed test are because of
error running git clone "git@github.com:huggingface/transformers.git": exit status 128, which is not related to this PR.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A small nit but LGTM otherwise.
Let's just declare _is_torch_version_greater_or_equal_than_2_2_0 in the src/transformers/utils/import_utils.py file

src/transformers/models/mistral/modeling_mistral.py Outdated Show resolved Hide resolved
@ehuaa ehuaa requested a review from ArthurZucker February 27, 2024 07:20
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost good to go

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, let's fix the conflicts and we should be able to merge.
EDIT: let's add a small test in the integration tests!

@ehuaa
Copy link
Contributor Author

ehuaa commented Feb 27, 2024

LGTM, let's fix the conflicts and we should be able to merge. EDIT: let's add a small test in the integration tests!

I have resolved the conflict, and the test you mentioned above i think we can use the test you wrote before, https://github.com/huggingface/transformers/blob/main/tests/models/mistral/test_modeling_mistral.py#L534
when you upgrade your torch version of huggingface environment to 2.2, and enable RUN_SLOW=yes, this test can check if this PR is right!
Additionaly, when we convert the tensor to contigous, it can solve the problem even with torch 2.1.1, so I think the test you wrote above is enough for testing!

@ehuaa
Copy link
Contributor Author

ehuaa commented Feb 27, 2024

When i tried to add a new test to check if sdpa work with sliding_window, the test below is failed.

I just modified your original test_model_7b_logits with a longer input with size of 4097 to check sliding_window feature, it failes as below:
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 1 / 4097 (0.0%)
E Greatest absolute difference: 0.20025348663330078 at index (0, 4096) (up to 0.01 allowed)
E Greatest relative difference: 0.06286640465259552 at index (0, 4096) (up to 0.01 allowed)
which i think is also a problem with older versions of transformers. @ArthurZucker

@ehuaa
Copy link
Contributor Author

ehuaa commented Feb 27, 2024

When i tried to add a new test to check if sdpa work with sliding_window, the test below is failed.

I just modified your original test_model_7b_logits with a longer input with size of 4097 to check sliding_window feature, it failes as below: E AssertionError: Tensor-likes are not close! E E Mismatched elements: 1 / 4097 (0.0%) E Greatest absolute difference: 0.20025348663330078 at index (0, 4096) (up to 0.01 allowed) E Greatest relative difference: 0.06286640465259552 at index (0, 4096) (up to 0.01 allowed) which i think is also a problem with older versions of transformers. @ArthurZucker

@slow
def test_model_7b_logits(self):
input_ids = [1] + [306, 338] * 2048
# input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
with torch.backends.cuda.sdp_kernel(
enable_flash=False,
enable_math=True,
enable_mem_efficient=False,
):
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto",attn_implementation="eager")
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
out = model(input_ids).logits.cpu()

    input_ids = [1] + [306, 338] * 2048  
    model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto",attn_implementation="sdpa")
    input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
    with torch.no_grad():
        out1 = model(input_ids).logits.cpu()
    # Expected mean on dim = -1
    torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)

    del model
    backend_empty_cache(torch_device)
    gc.collect()

@ehuaa ehuaa requested a review from ArthurZucker February 27, 2024 13:47
SunMarc and others added 4 commits February 27, 2024 09:58
…e#29264)

* Add compatibility with mps device

* fix

* typo and style
Co-authored-by: Joao Gante <joao@huggingface.co>
* [i18n-zh] Translate fsdp.md into Chinese

Signed-off-by: windsonsea <haifeng.yao@daocloud.io>

* apply suggestions from Fan-Lin

---------

Signed-off-by: windsonsea <haifeng.yao@daocloud.io>
ArthurZucker and others added 4 commits March 1, 2024 08:59
* [Whisper Tok] Update integration test

* make style
* Fix yolos processing

* Add back slow marker - protects for pycocotools in slow

* Slow decorator goes above copied from header
@psinger
Copy link

psinger commented Mar 1, 2024

Thanks for this PR.
A quick related question that is confusing me a bit: Doesn't that mean that Mistral and related models were buggy the last few versions, as sdpa is the default config, but it does not respect window_size?

SunMarc and others added 3 commits March 1, 2024 10:32
* Fix deprecated arg issue

* Trainer check too

* Check for dict or dataclass

* Simplify, make config always AcceleratorConfig

* Upstream to Trainer
)

* Correct zero division error in inverse sqrt scheduler

* default timescale to 10_000
@ArthurZucker
Copy link
Collaborator

ArthurZucker commented Mar 2, 2024

Probably yes. That why it's a bit critical. Mistral only had flash_attention_2 at the beginning so I suspect people used this config explicitly. But after merging sdpa without sliding window if it was not explicitly specified yes, it would use full attention so would use more RAM, and maybe be more accurate.
I don't really recalling having this reported that is why it went unnoticed for quite a while

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing missing for me is a test to make sure that sdpa with sliding_window gives the same results as flash_attention_2 with sliding_window

@ehuaa
Copy link
Contributor Author

ehuaa commented Mar 2, 2024

The only thing missing for me is a test to make sure that sdpa with sliding_window gives the same results as flash_attention_2 with sliding_window

@ArthurZucker , Yes, so i test mistral-7b logits of sdpa with sliding_window and eager with sliding_window, and the results are not the same.
The reason why not using flash_attention_2 is i only have a gpu card of v100 which is not compatiable with flash_attentnion_2, can you test it in your own machine above or should i use huggingface CLI environment to test?

@ArthurZucker
Copy link
Collaborator

ArthurZucker commented Mar 2, 2024

Yes, so i test mistral-7b logits of sdpa with sliding_window and eager with sliding_window, and the results are not the same.

That's less important than sdpa vs flash!
You can add the slow test already in MistralIntegrationTest and I will run it on my machines

@ehuaa
Copy link
Contributor Author

ehuaa commented Mar 2, 2024

Yes, so i test mistral-7b logits of sdpa with sliding_window and eager with sliding_window, and the results are not the same.

That's less important than sdpa vs flash! You can add the slow test already in MistralIntegrationTest and I will run it on my machines

Ok, I'll upload a new slow test of sdpa vs flash later.

@ehuaa ehuaa marked this pull request as draft March 2, 2024 11:09
@ehuaa ehuaa closed this Mar 2, 2024
@ehuaa
Copy link
Contributor Author

ehuaa commented Mar 2, 2024

After I git push, there're something weird to the changed files above so i closed this pr and try to open a new pull request. @ArthurZucker

@ArthurZucker
Copy link
Collaborator

alright! It think this fix is important, let's try to have this ready! Ping me on the next PR and link it with this one as well to get the full conv!

@ehuaa
Copy link
Contributor Author

ehuaa commented Mar 5, 2024

alright! It think this fix is important, let's try to have this ready! Ping me on the next PR and link it with this one as well to get the full conv!

#29407 @ArthurZucker Please review this pr you mentioned above.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.