
[GPTNeoX] Flex Attention + Refactor #34896

Open
wants to merge 8 commits into main

Conversation

vasqu (Contributor) commented Nov 23, 2024

What does this PR do?

Adds flex attention and the refactor according to #34809

However, I discovered several issues in the current version of gemma2 (#34282):

  • It seems that flex attention needs a transpose of the output afterwards, just like SDPA (see the sketch after this list)
  • Loading flex attention via from_pretrained didn't work, so the current tests fall back to another attention implementation (eager or SDPA, I don't remember which)
  • Testing would benefit from common tests like the ones for SDPA :D For now it's a bit of a hassle to add an integration test for every model when a more general test covering all subsequent models would do
  • I'm not familiar with BetterTransformer or the limitations of flex attention --> added some TODOs in case we need to check
  • Flex attention doesn't support dropout (or maybe I've overlooked something)
  • Setting model.config._attn_implementation = ... should be tracked somewhere and sanity-checked the same way it is on first load - for now it silently overwrites the value and could cause some ugly errors (tested by switching to flash attention 2 without FA2 installed)
  • Documentation should be added somewhere (probably the perf docs or something similar)
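
For reference, a minimal sketch of the transpose point from the first bullet, assuming torch >= 2.5 and the usual (batch, num_heads, seq_len, head_dim) layout; shapes and names are illustrative, not the PR's actual code:

import torch
from torch.nn.attention.flex_attention import flex_attention

batch, num_heads, seq_len, head_dim = 1, 8, 16, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# Like SDPA, flex attention returns (batch, num_heads, seq_len, head_dim), so the
# output still needs a transpose back to (batch, seq_len, num_heads, head_dim)
# before it can be reshaped into hidden states.
attn_output = flex_attention(query, key, value)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch, seq_len, num_heads * head_dim)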

So, tbh, I'm not sure whether to split this PR into several ones (e.g. a Gemma fix, general loading, general tests, docs, and then the subsequent models) or not.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker

vasqu (Contributor, Author) left a comment

A collection of comments which partially show the issues I listed above

Comment on lines +1821 to +1827
# TODO: add contribution notice?
raise ValueError(
    f"{cls.__name__} does not support an attention implementation through torch's flex_attention."
    " If you believe this error is a bug, please open an issue in Transformers GitHub repository"
    ' and load your model with the argument `attn_implementation="eager"` meanwhile.'
    ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)

Likely linking to #34809

Comment on lines +1836 to +1839
# TODO check for more edge cases as done in the other implementations
# _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
# if _is_bettertransformer:
#     return config

Just left it there since I'm not familiar with BetterTransformer and whether a check or something is needed here.

Comment on lines +367 to +369
# TODO check if some bugs cause push backs on the exact version
# NOTE: We require torch>=2.5.0 as it is the first release that ships flex attention
return version.parse(_torch_version) >= version.parse("2.5.0")

Also here unsure if we will encounter bugs ;)
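
For context, a hedged sketch of how such an availability gate is typically consumed when validating the requested implementation; the helper wrapper below is hypothetical, and it assumes the gate above is exposed as is_torch_flex_attn_available:

from transformers.utils import is_torch_flex_attn_available

def check_flex_attn_requested(attn_implementation: str) -> None:
    # Illustrative only: fail early on older torch instead of erroring inside the kernel.
    if attn_implementation == "flex_attention" and not is_torch_flex_attn_available():
        raise ImportError(
            "flex_attention requires torch>=2.5.0; upgrade torch or pick another "
            "attn_implementation such as 'sdpa' or 'eager'."
        )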

Comment on lines 462 to 484
@slow
def test_lm_generate_flex_attn_gptneox(self):
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
    for checkpointing in [True, False]:
        model = GPTNeoXForCausalLM.from_pretrained(
            "EleutherAI/pythia-410m-deduped", attn_implementation="flex_attention"
        )

        if checkpointing:
            model.gradient_checkpointing_enable()
        else:
            model.gradient_checkpointing_disable()
        model.to(torch_device)

        inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
        # The hub repo. is updated on 2023-04-04, resulting in poor outputs.
        # See: https://github.com/huggingface/transformers/pull/24193
        expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure"

        output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
        output_str = tokenizer.batch_decode(output_ids)[0]

        self.assertEqual(output_str, expected_output)

Would love to have common tests for this in the future (instead of per-model integration tests like this one).

Comment on lines -260 to -262
if key_length > self.bias.shape[-1]:
    self._init_bias(key_length, device=key.device)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

Removed since we handle this outside the forward for the attention masks; kept the buffers for BC so weights loading won't complain.

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

I think this is missing in gemma2 as well: it's using the config, but I'm unsure whether that's sufficient.
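
For context, a hedged sketch (not the exact library code) of how this flag is usually consumed in the FA2 path; the helper name and arguments are illustrative:

def resolve_causal_flag(is_causal: bool, query_length: int, uses_top_left_mask: bool) -> bool:
    # flash_attn<2.1 builds top-left aligned causal masks; for query_length == 1 no
    # mask is needed anyway, so causality is dropped to avoid the wrong alignment.
    if not uses_top_left_mask:
        return is_causal
    return is_causal and query_length != 1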

Comment on lines +378 to +386
if (
    self.training
    and self.config.attention_dropout > 0
    and self.config._attn_implementation == "flex_attention"
):
    logger.warning_once(
        f"Setting `attention_type` to `eager` because `dropout` is not supported in {attention_type}"
    )
    attention_type = "eager"

No dropout in flex attn

Comment on lines +266 to +267
# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()

Transpose which is (possibly) missing in gemma2

Comment on lines +175 to +187
input_dtype = query.dtype
if input_dtype == torch.float32:
    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )

    query = query.to(target_dtype)
    key = key.to(target_dtype)
    value = value.to(target_dtype)

attention_dropout = attention_dropout if training else 0.0

This PEFT check also seems to be missing in gemma2.
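
Regarding target_dtype in the snippet above: it is defined just before the quoted lines. A hedged sketch of the usual derivation follows; the helper, module, and attribute names are assumptions for illustration only:

import torch

def resolve_target_dtype(attention_module, config) -> torch.dtype:
    # Illustrative only: pick the dtype to cast back to when inputs were silently
    # upcast to float32 (e.g. by upcasted LayerNorms or PEFT adapters), since the
    # FA2 kernels only run in fp16/bf16.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    # Fall back to the dtype of the attention projection weights (GPTNeoX uses a
    # fused query_key_value projection).
    return attention_module.query_key_value.weight.dtype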

vasqu commented Nov 23, 2024

The CI failures seem unrelated; flaky tests (e.g. XLM, Qwen2VL).

vasqu commented Nov 23, 2024

Possible TODO -> fall back to eager when using a head mask in FA2/SDPA, and add head mask support in flex attention (should be possible via a score mod; see the sketch below).

Edit: Added now
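
A minimal sketch of the score-mod mechanism referred to above, using the canonical causal example; a head mask can be expressed the same way because the head index is passed into the score mod. Shapes and names are illustrative, not the code added in this PR:

import torch
from torch.nn.attention.flex_attention import flex_attention

batch, num_heads, seq_len, head_dim = 1, 4, 16, 32
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

def causal_score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
    # `score` is the raw query-key product for one (batch, head, query, key) entry;
    # returning -inf removes that entry from the softmax. A head mask would instead
    # look up a per-head bias via `head_idx`.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

attn_output = flex_attention(query, key, value, score_mod=causal_score_mod)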

dame-cell commented Nov 23, 2024

Yes, I was testing it for Gemma as well; it needed a transpose at the end too.

If you don't mind, could you check the pull request I did for Gemma? I keep failing some tests.

Also, Gemma 2 now supports new stuff in the configuration which confused me a lot: the attn logit soft capping.

Also, model.config._attn_implementation is not really handled correctly; for example, it does not actually use the chosen attention implementation.

Still working on the Gemma flex attention PR; I might help with the docs as well.

vasqu commented Nov 23, 2024

@dame-cell I'll take a look tomorrow! I'm busted for today :)

But as a quick fix to get loading handled correctly, look into my changes to the utils folder and modeling_utils. With those changes, loading should work as expected. Tbh, that's one of the main reasons why I think it might be better to split this into several PRs and get loading etc. right first before we start adding models.

Edit: One last thing to change would be to add _supports_flex_attn = True, like it's done for SDPA and FA2 (see the sketch below).
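
For illustration, the kind of opt-in flag meant in the edit above, mirroring the existing SDPA/FA2 flags on the pretrained model class; this is a sketch, not the PR's exact diff:

from transformers import PreTrainedModel

class GPTNeoXPreTrainedModel(PreTrainedModel):
    ...
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    # Opt the architecture into the flex attention path, mirroring the flags above.
    _supports_flex_attn = True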

dame-cell replied:
Hmmm, ohh, I get it, I see. Thanks for letting me know 😀
