Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPTNeoX] Flex Attention + Refactor #34896

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
is_peft_available,
is_remote_url,
is_safetensors_available,
is_torch_flex_attn_available,
is_torch_greater_or_equal,
is_torch_sdpa_available,
is_torch_xla_available,
Expand Down Expand Up @@ -1338,6 +1339,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# SDPA support
_supports_sdpa = False

# Flex Attention support
_supports_flex_attn = False

# Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
_supports_cache_class = False
_supports_static_cache = False
Expand Down Expand Up @@ -1544,6 +1548,10 @@ def _autoset_attn_implementation(
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
if cls._supports_flex_attn:
message += (
', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)'
)
raise ValueError(message + ".")

# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
Expand Down Expand Up @@ -1578,6 +1586,8 @@ def _autoset_attn_implementation(
hard_check_only=False,
check_device_map=check_device_map,
)
elif requested_attn_implementation == "flex_attention":
config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa(
Expand Down Expand Up @@ -1774,7 +1784,7 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra
"""
Checks the availability of SDPA for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "sdpa" so that the model can initialize the correct attention module.
"""
if hard_check_only:
if not cls._supports_sdpa:
Expand All @@ -1799,6 +1809,40 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> Pretra
config._attn_implementation = "sdpa"
return config

@classmethod
def _check_and_enable_flex_attn(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
Checks the availability of Flex Attention for a given model.
If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flex_attention" so that the model can initialize the correct attention module.
"""
if hard_check_only:
if not cls._supports_flex_attn:
# TODO: add contribution notice?
raise ValueError(
f"{cls.__name__} does not support an attention implementation through torch's flex_attention."
" If you believe this error is a bug, please open an issue in Transformers GitHub repository"
' and load your model with the argument `attn_implementation="eager"` meanwhile.'
' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
Comment on lines +1821 to +1827
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likely linking to #34809

if not is_torch_flex_attn_available():
raise ImportError(
"PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
)

if not is_torch_flex_attn_available() or not cls._supports_flex_attn:
return config

# TODO check for more edge cases as done in the other implementations
# _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
# if _is_bettertransformer:
# return config
Comment on lines +1836 to +1839
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just let it there since I'm not familiar with better transformers and if there needs to be a check or smthn


if not hard_check_only:
config._attn_implementation = "flex_attention"

return config

def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
Expand Down
Loading