[GPTNeoX] Flex Attention + Refactor #34896
base: main
Conversation
A collection of comments which partially show the issues I listed above
# TODO: add contribution notice?
raise ValueError(
    f"{cls.__name__} does not support an attention implementation through torch's flex_attention."
    " If you believe this error is a bug, please open an issue in Transformers GitHub repository"
    ' and load your model with the argument `attn_implementation="eager"` meanwhile.'
    ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
Likely linking to #34809
# TODO check for more edge cases as done in the other implementations
# _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
# if _is_bettertransformer:
#     return config
I left this in since I'm not familiar with BetterTransformer and whether a check or something similar is needed here.
# TODO: check whether any bugs push back the exact version requirement
# NOTE: We require torch>=2.5.0 as it is the first release to include flex attention
return version.parse(_torch_version) >= version.parse("2.5.0")
Same here, unsure whether we'll run into bugs ;)
@slow
def test_lm_generate_flex_attn_gptneox(self):
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
    for checkpointing in [True, False]:
        model = GPTNeoXForCausalLM.from_pretrained(
            "EleutherAI/pythia-410m-deduped", attn_implementation="flex_attention"
        )

        if checkpointing:
            model.gradient_checkpointing_enable()
        else:
            model.gradient_checkpointing_disable()
        model.to(torch_device)

        inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)
        # The hub repo. is updated on 2023-04-04, resulting in poor outputs.
        # See: https://github.com/huggingface/transformers/pull/24193
        expected_output = "My favorite food is a good old-fashioned, old-fashioned, old-fashioned.\n\nI'm not sure"

        output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
        output_str = tokenizer.batch_decode(output_ids)[0]

        self.assertEqual(output_str, expected_output)
Would love to have common tests in the future (instead)
if key_length > self.bias.shape[-1]:
    self._init_bias(key_length, device=key.device)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
Removed since we handle this outside the forward for the attention masks; kept the buffers for BC so weights loading won't complain.
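For reference, here's a minimal sketch of how a causal mask can be built once outside the attention forward with torch's flex attention utilities; the helper and dimension names below are illustrative, not the PR's actual code:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask_mod(b, h, q_idx, kv_idx):
    # A query position may attend to itself and all earlier key positions.
    return q_idx >= kv_idx

# Built once per forward pass and shared by every attention layer.
batch, heads, seq_len, head_dim = 1, 8, 128, 64
block_mask = create_block_mask(
    causal_mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu"
)

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)
out = flex_attention(q, k, v, block_mask=block_mask)  # (batch, heads, seq_len, head_dim)
```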
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
I think this is missing in gemma2 as well: it relies on the config there, and I'm not sure that's sufficient.
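For context, the flag is normally consumed when deciding whether to pass `causal=True` to flash attention; a small self-contained sketch of that decision (the helper name is hypothetical):

```python
def resolve_causal_flag(is_causal: bool, query_length: int, uses_top_left_mask: bool) -> bool:
    # flash_attn<2.1 builds a top-left aligned causal mask, so single-token decoding
    # (query_length == 1) must drop the causal flag to avoid masking the wrong positions.
    if not uses_top_left_mask:
        return is_causal
    return is_causal and query_length != 1
```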
if (
    self.training
    and self.config.attention_dropout > 0
    and self.config._attn_implementation == "flex_attention"
):
    logger.warning_once(
        f"Setting `attention_type` to `eager` because `dropout` is not supported in {attention_type}"
    )
    attention_type = "eager"
No dropout in flex attn, hence the fallback to eager when training with attention_dropout > 0.
# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()
This is the transpose that is (possibly) missing in gemma2.
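To illustrate, a tiny shape-flow sketch (dimensions are illustrative) of why the transpose is needed before merging heads back into the hidden size:

```python
import torch

batch, heads, seq, head_dim = 2, 8, 16, 64

# flex/sdpa-style attention kernels return (batch, heads, seq, head_dim)
attn_output = torch.randn(batch, heads, seq, head_dim)

# Bring heads next to head_dim, then merge them into the hidden dimension
attn_output = attn_output.transpose(1, 2).contiguous()           # (batch, seq, heads, head_dim)
attn_output = attn_output.reshape(batch, seq, heads * head_dim)  # (batch, seq, hidden_size)
print(attn_output.shape)  # torch.Size([2, 16, 512])
```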
input_dtype = query.dtype
if input_dtype == torch.float32:
    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )

    query = query.to(target_dtype)
    key = key.to(target_dtype)
    value = value.to(target_dtype)

attention_dropout = attention_dropout if training else 0.0
This peft check also seems to be missing in gemma2
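The cast above needs a `target_dtype` determined beforehand; a rough sketch of how that's usually derived in the flash-attention code paths (the attribute names follow that pattern and are assumptions here, e.g. `query_key_value` for GPTNeoX):

```python
import torch
import torch.nn as nn

def resolve_target_dtype(module: nn.Module, config) -> torch.dtype:
    # Pick the dtype to cast hidden states back to when PEFT / layer-norm upcasting
    # has silently promoted them to float32.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    # Fall back to the dtype of the attention projection weights.
    return module.query_key_value.weight.dtype
```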
The CI failures seem unrelated: flaky tests (e.g. XLM, Qwen2VL).
Possible TODO -> fall back to eager when a head mask is used with fa2/sdpa, and add head mask support in flex attention (should be possible via score_mod; rough sketch below). Edit: added now.
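A rough sketch of what injecting a per-head weight through `score_mod` could look like; note this modulates pre-softmax scores, so it is not numerically identical to the eager path, which multiplies post-softmax probabilities, and fully masked heads would need extra care:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

batch, heads, seq, head_dim = 1, 4, 32, 64
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Hypothetical per-head weights (1.0 leaves a head untouched).
head_weights = torch.tensor([1.0, 0.5, 1.0, 0.25])

def head_weight_score_mod(score, b, h, q_idx, kv_idx):
    # Scale the pre-softmax attention score of head `h`.
    return score * head_weights[h]

out = flex_attention(q, k, v, score_mod=head_weight_score_mod)  # (batch, heads, seq, head_dim)
```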
Yes, I was testing it for Gemma as well and it needed a transpose at the end too. If you don't mind, could you check the pull request I did for Gemma? I keep failing some tests. Also, gemma2 now supports new options in the configuration, which confused me a lot. And model.config._attn_implementation is not really handled correctly: for example, it does not actually use the attention implementation you choose. Still working on the Gemma flex attention PR; it might help with the docs as well.
@dame-cell I'll take a look tomorrow! I'm beat for today :) As a quick note, to get loading handled correctly, look into my changes in the utils folder and modeling_utils. With those changes, loading should work as expected. Tbh, that's one of the main reasons why I think it might be better to split this into several PRs and get loading etc. right first before we start adding models. Edit: One last thing to change would be to add …
Hmm, oh I get it, I see. Thanks for letting me know 😀
What does this PR do?
Adds flex attention and the refactor according to #34809
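In short, once this lands the new implementation can be selected at load time like sdpa/fa2; a quick usage sketch (checkpoint and dtype chosen for illustration):

```python
import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-410m-deduped",
    attn_implementation="flex_attention",  # requires torch>=2.5.0
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("My favorite food is", return_tensors="pt")
output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
print(tokenizer.batch_decode(output_ids)[0])
```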
However, I discovered several issues in the current version of gemma2 (#34282): `model.config._attn_implementation = ...` should be tracked somewhere and checked for sanity as done the first time; for now it silently overwrites and could cause some ugly errors (tested by changing to flash attention 2 while not having fa2 installed).

So tbh, I'm not sure whether to split this PR into several ones, e.g. a gemma fix, general loading, general tests, docs, and then subsequent models, or not.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker