[Mistral&Mixtral] Add sliding window for sdpa #29407
Review discussion:

Thanks! Let's throw in a generation test as well and we should be good to go! 🤗

OK. The flash vs. sdpa test I submitted above does not pass; have you debugged it? I'm also curious why it fails here.

No, I have not debugged it and I won't have the bandwidth. Do you need help on this? cc @younesbelkada, I think this is pretty important.

As for the generation test you mentioned above, I think `test_model_7b_long_prompt_sdpa` is enough: it already covers generation with sdpa and a sliding window (a rough sketch of such a test follows below).

I also see that Gemma has a similar sdpa logits test to the one I committed (https://github.com/huggingface/transformers/blob/main/tests/models/gemma/test_modeling_gemma.py#L471). That test apparently passes, so maybe it can help with the debugging.
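For reference, a long-prompt sdpa generation test along the lines discussed above might look like the sketch below. This is only an illustration of the general shape of such a test; the decorators, checkpoint, prompt, and expected token ids are assumptions, and it is not a copy of the actual `test_model_7b_long_prompt_sdpa`.

```python
@slow
@require_torch_sdpa
def test_model_7b_long_prompt_sdpa(self):
    # Sketch only: a prompt longer than the 4096-token sliding window, so the
    # sliding-window code path is exercised during prefill and generation.
    EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]  # assumed expectation, for illustration
    input_ids = [1] + [306, 338] * 2048
    model = MistralForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
    )
    input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
    generated_ids = model.generate(input_ids, max_new_tokens=4, do_sample=False)
    # The prompt alternates tokens 306 and 338, so greedy decoding should continue the pattern.
    self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
```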
```diff
@@ -483,6 +483,31 @@ def test_model_7b_logits(self):
         backend_empty_cache(torch_device)
         gc.collect()
 
+    @slow
+    @require_flash_attn
+    @require_torch_sdpa
+    def test_model_7b_logits_long_with_sdpa_and_flash2(self):
+        input_ids = [1] + [306, 338] * 2048
+        model = MistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2"
+        )
+        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+        with torch.no_grad():
+            out = model(input_ids).logits.cpu()
+
+        input_ids = [1] + [306, 338] * 2048
```
```diff
+        model = MistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
+        )
```
Review comment on lines +491 to +501: I am getting an error because by default the model seems to be loaded in float32.

Reply: This passes for me.
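The suggested change itself is not preserved above. One plausible fix for the float32 default, stated here as an assumption rather than the original suggestion, is to load both checkpoints in half precision, since `flash_attention_2` only supports fp16/bf16:

```python
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    device_map="auto",
    torch_dtype=torch.float16,  # avoid the float32 default; flash_attention_2 requires fp16/bf16
    attn_implementation="flash_attention_2",
)
```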
```diff
+        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
```
```diff
+        with torch.no_grad():
+            out1 = model(input_ids).logits.cpu()
+        torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)
```
Review comment: let's make sure we test all logits, not just the mean (i.e. compare `out` and `out1` directly).

Reply: with this change, the test fails:

```
>       torch.testing.assert_close(out, out1, atol=1e-4, rtol=1e-4)
E       AssertionError: Tensor-likes are not close!
E
E       Mismatched elements: 90967735 / 131104000 (69.4%)
E       Greatest absolute difference: 0.328125 at index (0, 2310, 338) (up to 0.0001 allowed)
E       Greatest relative difference: 114689.0 at index (0, 1267, 4581) (up to 0.0001 allowed)
```
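The mean-based check passing while the elementwise check fails is not surprising in itself: averaging over the vocabulary dimension cancels out per-token discrepancies. A small, self-contained toy illustration (random tensors, not actual model outputs):

```python
import torch

torch.manual_seed(0)
# Shapes loosely mimic (batch, sequence, vocab) logits.
out = torch.randn(1, 8, 32000)
out1 = out + 0.3 * torch.randn(1, 8, 32000)  # elementwise noise with std ~0.3

# Passes: the noise averages out over the 32000-dim vocab axis.
torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)

# Fails: individual logits differ by far more than 1e-4, as in the report above.
try:
    torch.testing.assert_close(out, out1, atol=1e-4, rtol=1e-4)
except AssertionError as e:
    print(e)
```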
```diff
+
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()
+
     @slow
     def test_model_7b_generation(self):
         EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"""
```
Review comment: The issue here is that `_prepare_4d_causal_attention_mask_for_sdpa` seems to return `None` if `attention_mask` is `None` (which is the case in the test), while if we actually want to use the sliding window we need to return the full causal mask. cc @fxmarty
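To make the point concrete: the mask sdpa needs when the sliding window is active is not the plain causal mask that `is_causal=True` would give, because tokens further than the window size in the past must also be masked out. A rough, self-contained sketch of that mask (an illustrative helper, not the transformers implementation):

```python
import torch

def sliding_window_causal_mask(seq_len: int, sliding_window: int, dtype=torch.float32):
    # Hypothetical helper: position i may attend to positions j with
    # i - sliding_window < j <= i; everything else gets a large negative bias.
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]                       # j <= i
    in_window = (idx[:, None] - idx[None, :]) < sliding_window  # i - j < window
    allowed = causal & in_window
    mask = torch.zeros(seq_len, seq_len, dtype=dtype)
    mask.masked_fill_(~allowed, torch.finfo(dtype).min)
    return mask[None, None, :, :]  # (1, 1, seq_len, seq_len), broadcastable over batch/heads

# With a window of 3, token 5 can attend to tokens 3, 4, 5 but not to 0-2.
print((sliding_window_causal_mask(seq_len=6, sliding_window=3)[0, 0] == 0).int())
```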