
Fix flash attention bugs with Mistral and Falcon #27625

Merged (6 commits) on Nov 21, 2023

Conversation

fxmarty (Contributor) commented on Nov 21, 2023

This PR fixes some important bugs in the Mistral and Falcon flash attention integrations.

#26933 broke flash attention for Falcon due to a modification of the tensor layout.

The following tests were not passing:

```
FAILED tests/models/mistral/test_modeling_mistral.py::MistralModelTest::test_flash_attn_2_generate_padding_right - AssertionError: ValueError not raised
FAILED tests/models/mistral/test_modeling_mistral.py::MistralModelTest::test_flash_attn_2_inference_padding_right - AssertionError: ValueError not raised

FAILED tests/models/falcon/test_modeling_falcon.py::FalconModelTest::test_flash_attn_2_generate_left_padding - RuntimeError: CUDA error: device-side assert triggered
FAILED tests/models/falcon/test_modeling_falcon.py::FalconModelTest::test_flash_attn_2_generate_padding_right - RuntimeError: CUDA error: device-side assert triggered
FAILED tests/models/falcon/test_modeling_falcon.py::FalconModelTest::test_flash_attn_2_generate_use_cache - RuntimeError: CUDA error: device-side assert triggered
FAILED tests/models/falcon/test_modeling_falcon.py::FalconModelTest::test_flash_attn_2_inference - RuntimeError: CUDA error: device-side assert triggered
FAILED tests/models/falcon/test_modeling_falcon.py::FalconModelTest::test_flash_attn_2_inference_padding_right - RuntimeError: CUDA error: device-side assert triggered
```

Additionally, Falcon with FA2 is not really usable on main due to a shape error: the tensors are currently laid out as [batch_size, num_heads, seqlen, head_dim] instead of the required [batch_size, seqlen, num_heads, head_dim].
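For illustration, a minimal sketch of the layout mismatch (the shapes come from the description above; the tensor sizes and variable names are illustrative):

```python
import torch

batch_size, num_heads, seqlen, head_dim = 2, 8, 16, 64

# Layout currently produced on main, per the description above:
query_states = torch.randn(batch_size, num_heads, seqlen, head_dim)

# Flash Attention 2 expects [batch_size, seqlen, num_heads, head_dim],
# so the two middle dimensions must be swapped before the kernel call:
query_states = query_states.transpose(1, 2)
assert query_states.shape == (batch_size, seqlen, num_heads, head_dim)
```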

```diff
@@ -838,7 +838,7 @@ def forward(
     attention_mask is not None
     and hasattr(self.config, "_flash_attn_2_enabled")
     and self.config._flash_attn_2_enabled
-    and past_key_values is not None
+    and use_cache
```
fxmarty (Contributor, Author) commented:

In the first autoregressive pass, past_key_values is None. In the following passes, we may have masks such as

```
[[1, 1, 0, 0, 1],
 [1, 1, 1, 1, 1]]
```

and is_padding_right would then wrongly evaluate to False.
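To make this concrete, a toy reproduction (assuming the right-padding check looks for a 0 in the last mask column, along the lines of `attention_mask[:, -1].sum().item() != batch_size`):

```python
import torch

def is_padding_right(mask: torch.Tensor) -> bool:
    # Right padding leaves a 0 in the last column of the attention mask.
    return mask[:, -1].sum().item() != mask.shape[0]

# First pass: the padding is still at the end of the prompt.
first_mask = torch.tensor([[1, 1, 0, 0],
                           [1, 1, 1, 1]])

# Later pass: generated tokens append 1s after the padding, as in the mask above.
later_mask = torch.tensor([[1, 1, 0, 0, 1],
                           [1, 1, 1, 1, 1]])

print(is_padding_right(first_mask))  # True  -> the ValueError fires as intended
print(is_padding_right(later_mask))  # False -> the check is silently skipped
```

Gating the check on use_cache instead of past_key_values is not None runs it on the first pass, where the padding is still visible in the last column.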

HuggingFaceDocBuilderDev commented on Nov 21, 2023

The documentation is not available anymore as the PR was closed or merged.

younesbelkada (Contributor) left a comment:

Clean! Thanks!

ArthurZucker (Collaborator) left a comment:

Thanks for fixing!

```python
model.save_pretrained(tmpdirname)

dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Mistral apparently does not support right padding + use_cache with FA2.
```
A collaborator commented on this snippet:

It works, but you'll get terrible results because the cache will cut the non-padded values first.
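For intuition, a toy sketch of that failure mode (not the real cache code; it only mimics a rolling buffer that keeps the most recent positions, as a sliding-window KV cache does):

```python
# Rolling buffer keeping the last `window` positions.
window = 4
prompt = ["I", "like", "cats", "<pad>", "<pad>"]  # right-padded prompt
cache = prompt[-window:]
print(cache)  # ['like', 'cats', '<pad>', '<pad>'] -> real context evicted before the pads
```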

amyeroberts (Collaborator) left a comment:

Thanks for adding!

I just have a small Q about the tests

fxmarty (Contributor, Author) commented on Nov 21, 2023

Hi @amyeroberts, thank you for the review! I added test_flash_attn_2_generate_use_cache to test_modeling_llama.py by mistake when it was meant to be added to test_modeling_mistral.py, hence the confusion. Apologies!

amyeroberts (Collaborator) commented:

@fxmarty Thanks for clarifying. Tbh, I'm still a bit confused by the tests: it's not clear to me how this explicitly tests the cache, since use_cache isn't set anywhere 😅

fxmarty (Contributor, Author) commented on Nov 21, 2023

@amyeroberts good catch indeed... I just checked: during generation we go through this line, which sets use_cache=True:

```python
model_kwargs["use_cache"] = generation_config.use_cache
```

and the generation config in turn uses

```python
self.use_cache = kwargs.pop("use_cache", True)
```

fxmarty merged commit 82cc0a7 into huggingface:main on Nov 21, 2023
amyeroberts (Collaborator) commented on Nov 21, 2023

@fxmarty OK - thanks for explaining! As a follow-up, could you add use_cache=True explicitly to the tests? That way it's clearer for anyone reading the code, and the behaviour isn't subject to silently going untested if the configs or the config handling change.
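For instance, a hypothetical version of the test call with the flag made explicit (the checkpoint and inputs are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

inputs = tokenizer(["Hello world"], return_tensors="pt")
# use_cache passed explicitly, so the test no longer depends on the config default:
output = model.generate(**inputs, max_new_tokens=4, use_cache=True)
```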

fxmarty (Contributor, Author) commented on Nov 21, 2023

For sure @amyeroberts, I will ping you there. Sorry, I should have waited before merging.

amyeroberts (Collaborator) commented:

@fxmarty No worries! It doesn't affect the functionality of this PR so it's fine to be done separately :)
