[Mistral&Mixtral] Add sliding window for sdpa #29407

Closed
wants to merge 21 commits

Commits (21)
97cff89  add sliding window param to sdpa after torch==2.2.0  (ehuaa, Feb 22, 2024)
d0baf19  add sliding window param to sdpa after torch==2.2.0  (ehuaa, Feb 22, 2024)
c9dacb8  revert add sliding window for qwen2 because of numerical error  (ehuaa, Feb 22, 2024)
9c7bb07  Merge branch 'add_sliding_window_for_sdpa' of https://mirror.ghproxy.…  (ehuaa, Feb 22, 2024)
f464d15  remove adding sliding_window param to qwen2 because of numerical error  (ehuaa, Feb 22, 2024)
f4c21d0  fix style  (ehuaa, Feb 22, 2024)
a9f1571  only add non-contigous mask to qwen2 due to numerical error  (ehuaa, Feb 22, 2024)
8b84d68  revert non-contigous tensor modification  (ehuaa, Feb 24, 2024)
6bee407  move torch version judgement to import_utils for usability  (ehuaa, Feb 27, 2024)
cf225ff  Merge branch 'main' into add_sliding_window_for_sdpa  (ehuaa, Feb 27, 2024)
9c4f0b0  delete deprecated is_flash_attn in import_utils.py  (ehuaa, Feb 27, 2024)
773d8c8  add sliding window param to sdpa after torch==2.2.0  (ehuaa, Feb 22, 2024)
4fab890  revert add sliding window for qwen2 because of numerical error  (ehuaa, Feb 22, 2024)
074d47a  add sliding window param to sdpa after torch==2.2.0  (ehuaa, Feb 22, 2024)
6972cdf  remove adding sliding_window param to qwen2 because of numerical error  (ehuaa, Feb 22, 2024)
8611c2d  fix style  (ehuaa, Feb 22, 2024)
510f24f  only add non-contigous mask to qwen2 due to numerical error  (ehuaa, Feb 22, 2024)
cbfc413  revert non-contigous tensor modification  (ehuaa, Feb 24, 2024)
6d590b0  move torch version judgement to import_utils for usability  (ehuaa, Feb 27, 2024)
da327f7  upload a test for compare flash vs sdpa for sliding window in Mistral  (ehuaa, Mar 2, 2024)
8171137  Merge branch 'add_sliding_window_for_sdpa' of https://mirror.ghproxy.…  (ehuaa, Mar 2, 2024)

2 changes: 2 additions & 0 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -39,6 +39,7 @@
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_version_greater_or_equal_than_2_2_0,
logging,
replace_return_docstrings,
)
@@ -1006,6 +1007,7 @@ def forward(
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window if is_torch_version_greater_or_equal_than_2_2_0 else None,
)
else:
# 4d mask is passed through the layers
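
For context, the sliding_window argument forwarded above is meant to restrict each query to a fixed-size window of the most recent keys. Below is a minimal, self-contained sketch of such a mask and of feeding it to torch.nn.functional.scaled_dot_product_attention; it is illustrative only, the helper name is made up, and it is not the _prepare_4d_causal_attention_mask_for_sdpa implementation.

import torch

def sliding_window_causal_mask(seq_len, window, dtype=torch.float32, device="cpu"):
    # Boolean grid of allowed (query i, key j) pairs: causal (j <= i) and
    # within the last `window` positions (j > i - window).
    i = torch.arange(seq_len, device=device).unsqueeze(-1)
    j = torch.arange(seq_len, device=device).unsqueeze(0)
    allowed = (j <= i) & (j > i - window)
    # Additive float mask: 0 where attention is allowed, a large negative value elsewhere.
    mask = torch.zeros(seq_len, seq_len, dtype=dtype, device=device)
    mask = mask.masked_fill(~allowed, torch.finfo(dtype).min)
    # Shape (1, 1, q_len, kv_len) so it broadcasts over batch and heads.
    return mask[None, None, :, :]

# Toy usage: (batch, heads, seq, head_dim) tensors with a window of 4 tokens.
q = k = v = torch.randn(1, 8, 16, 64)
attn_mask = sliding_window_causal_mask(seq_len=16, window=4)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
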
3 changes: 3 additions & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -47,6 +47,7 @@
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_version_greater_or_equal_than_2_2_0,
logging,
replace_return_docstrings,
)
@@ -60,6 +61,7 @@

_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
@@ -1190,6 +1192,7 @@ def forward(
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window if is_torch_version_greater_or_equal_than_2_2_0 else None,
Collaborator comment on this line: The issue here is that _prepare_4d_causal_attention_mask_for_sdpa seems to return None if attention_mask is None (which is the case in the test), while if we actually want to use the sliding window we need to return the full causal mask. cc @fxmarty (a minimal sketch of this behaviour follows just after this diff)

)
else:
# 4d mask is passed through the layers
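
Regarding the review comment above, here is a compact sketch of the behaviour being described, under the assumption that no padding mask is supplied; the helper name is hypothetical, not the real _prepare_4d_causal_attention_mask_for_sdpa. Returning None is only safe when plain causal attention is wanted, because SDPA can then use its is_causal fast path; once a sliding window is requested, the banded mask has to be materialized even though attention_mask is None.

import torch

def sdpa_mask_when_attention_mask_is_none(seq_len, sliding_window=None, dtype=torch.float32):
    if sliding_window is None:
        # Plain causal attention: no explicit mask needed, SDPA handles it with is_causal=True.
        return None
    # Sliding window requested: build the full 4D causal-plus-band mask,
    # otherwise the window restriction is silently dropped.
    i = torch.arange(seq_len).unsqueeze(-1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)   # key positions
    allowed = (j <= i) & (j > i - sliding_window)
    mask = torch.zeros(seq_len, seq_len, dtype=dtype)
    return mask.masked_fill(~allowed, torch.finfo(dtype).min)[None, None]
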
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -188,6 +188,7 @@
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torch_version_greater_or_equal_than_2_2_0,
is_torch_xpu_available,
is_torchaudio_available,
is_torchdistx_available,
4 changes: 4 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -665,6 +665,10 @@ def is_flash_attn_greater_or_equal_2_10():
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")


def is_torch_version_greater_or_equal_than_2_2_0():
return version.parse(get_torch_version()) >= version.parse("2.2.0")


def is_torchdistx_available():
return _torchdistx_available

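
The helper added to import_utils.py is a small version gate. A sketch of how a call site can consume it follows; get_torch_version is stubbed locally so the snippet stands alone, and the sliding-window value is a placeholder. Since the helper is defined as a function, call sites need the trailing parentheses to obtain a boolean.

import importlib.metadata
from packaging import version

def get_torch_version():
    # Local stand-in for transformers' own get_torch_version helper.
    return importlib.metadata.version("torch")

def is_torch_version_greater_or_equal_than_2_2_0():
    return version.parse(get_torch_version()) >= version.parse("2.2.0")

# Forward the sliding window only on torch >= 2.2.0, mirroring the gate used in the diffs above.
sliding_window = 4096 if is_torch_version_greater_or_equal_than_2_2_0() else None
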
25 changes: 25 additions & 0 deletions tests/models/mistral/test_modeling_mistral.py
Collaborator: Thanks! Let's throw in a generation test as well and we should be good to go! 🤗

Contributor Author:
> Thanks! Let's throw in a generation test as well and we should be good to go! 🤗

Ok, and the flash vs sdpa test I submitted above does not pass; have you debugged it? I'm also curious about why it fails here.

Collaborator: No, I have not debugged it and I won't have the bandwidth; do you need help on this? cc @younesbelkada, I think this is pretty important.

Contributor Author:
> No, I have not debugged it and I won't have the bandwidth; do you need help on this? cc @younesbelkada, I think this is pretty important.

As for the generation test you mentioned above, I think test_model_7b_long_prompt_sdpa is enough; it covers generation with sdpa and a sliding window.

Contributor Author:
> No, I have not debugged it and I won't have the bandwidth; do you need help on this? cc @younesbelkada, I think this is pretty important.

And I see that Gemma has a similar sdpa logits test to the one I committed (https://github.com/huggingface/transformers/blob/main/tests/models/gemma/test_modeling_gemma.py#L471). I think that test passes, so maybe it can help with the debugging.

@@ -483,6 +483,31 @@ def test_model_7b_logits(self):
backend_empty_cache(torch_device)
gc.collect()

@slow
@require_flash_attn
@require_torch_sdpa
def test_model_7b_logits_long_with_sdpa_and_flash2(self):
input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2"
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
out = model(input_ids).logits.cpu()

input_ids = [1] + [306, 338] * 2048
Collaborator suggested change (drop the duplicated line):
- input_ids = [1] + [306, 338] * 2048

model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
)
Collaborator comment on lines +491 to +501:

Suggested change (load both models in torch.bfloat16):
- model = MistralForCausalLM.from_pretrained(
-     "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2"
- )
- input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
- with torch.no_grad():
-     out = model(input_ids).logits.cpu()
- input_ids = [1] + [306, 338] * 2048
- model = MistralForCausalLM.from_pretrained(
-     "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
- )
+ model = MistralForCausalLM.from_pretrained(
+     "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
+ )
+ input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+ with torch.no_grad():
+     out = model(input_ids).logits.cpu()
+ input_ids = [1] + [306, 338] * 2048
+ model = MistralForCausalLM.from_pretrained(
+     "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
+ )

I am getting an error because by default it seems to be float32.

Collaborator: this passes for me.

input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
Collaborator suggested change (drop the repeated tensor conversion):
- input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)

with torch.no_grad():
out1 = model(input_ids).logits.cpu()
torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)
Collaborator: let's make sure we test all logits, not just the mean.

Suggested change:
- torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)
+ torch.testing.assert_close(out, out1, atol=1e-4, rtol=1e-4)

with this, the test is failing:

>       torch.testing.assert_close(out, out1, atol=1e-4, rtol=1e-4)
E       AssertionError: Tensor-likes are not close!
E
E       Mismatched elements: 90967735 / 131104000 (69.4%)
E       Greatest absolute difference: 0.328125 at index (0, 2310, 338) (up to 0.0001 allowed)
E       Greatest relative difference: 114689.0 at index (0, 1267, 4581) (up to 0.0001 allowed)


del model
backend_empty_cache(torch_device)
gc.collect()

@slow
def test_model_7b_generation(self):
EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"""
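
Folding the review suggestions on the new test together (load both models in torch.bfloat16, build input_ids once, compare the full logits rather than their means), the test would look roughly like the sketch below. It assumes the same imports and decorators as the surrounding test file; the atol/rtol values are the reviewer's suggestion, and per the failure pasted above the remaining numerical gap, not the test structure, is the open question.

    @slow
    @require_flash_attn
    @require_torch_sdpa
    def test_model_7b_logits_long_with_sdpa_and_flash2(self):
        input_ids = [1] + [306, 338] * 2048

        # Reference run through the FlashAttention-2 path.
        model = MistralForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-v0.1",
            device_map="auto",
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
        input_tensor = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
        with torch.no_grad():
            out_flash = model(input_tensor).logits.cpu()
        del model
        backend_empty_cache(torch_device)
        gc.collect()

        # Same prompt through the SDPA path, which this PR intends to make honour the sliding window.
        model = MistralForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-v0.1",
            device_map="auto",
            attn_implementation="sdpa",
            torch_dtype=torch.bfloat16,
        )
        input_tensor = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
        with torch.no_grad():
            out_sdpa = model(input_tensor).logits.cpu()

        torch.testing.assert_close(out_flash, out_sdpa, atol=1e-4, rtol=1e-4)

        del model
        backend_empty_cache(torch_device)
        gc.collect()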