src/transformers/integrations/flex_attention.py (11 additions, 2 deletions)

@@ -31,6 +31,7 @@
import torch

from ..utils import is_torch_flex_attn_available
from ..utils.import_utils import _torch_version


if is_torch_flex_attn_available():
@@ -60,8 +61,16 @@ def __init__(self):
"""
Initialize or update the singleton instance.
"""
        if self._is_flex_compiled is False:
            self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor")
        if not self._is_flex_compiled:
            # In PyTorch 2.6.0, there is a known issue with flex attention compilation that can
            # cause errors during training. The suggested fix is to compile with
            # mode="max-autotune-no-cudagraphs"; see https://github.com/pytorch/pytorch/issues/146260.
            if _torch_version == "2.6.0":
Contributor:
Could you add a comment here to describe why this is the case, e.g. by linking to the issue with a small description?

It might also be nice to use/create something like:

def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):

Collaborator:
I tested on my end; this should fix training!

                self._compiled_flex_attention = torch.compile(
                    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
                )
            else:
                self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
            self._is_flex_compiled = True

    def __call__(self):
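A minimal standalone sketch of the version-gated compile this diff adds, outside the singleton wrapper (assuming PyTorch 2.5+ with flex attention available and a CUDA device for the test call; the names compiled_flex, q, k, v are illustrative and not part of the PR):

# Sketch only: mirrors the 2.6.0 workaround above outside transformers.
import torch
from torch.nn.attention.flex_attention import flex_attention

if torch.__version__.startswith("2.6.0"):
    # Avoid cudagraphs when compiling flex attention on 2.6.0,
    # per https://github.com/pytorch/pytorch/issues/146260.
    compiled_flex = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
else:
    compiled_flex = torch.compile(flex_attention, dynamic=False)

if torch.cuda.is_available():
    # (batch, heads, seq_len, head_dim)
    q, k, v = (torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))
    out = compiled_flex(q, k, v)
    print(out.shape)  # torch.Size([1, 4, 128, 64])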
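On the Contributor's suggestion above, a sketch of what an is_torch_greater_or_equal-style helper could look like, built on packaging.version. This illustrates the idea only and is not the helper transformers actually ships; the names needs_no_cudagraphs and compile_kwargs are made up for the example:

import torch
from packaging import version


def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False) -> bool:
    # Hypothetical helper in the spirit of the review comment; semantics are assumptions.
    installed = version.parse(torch.__version__)
    if accept_dev:
        # Compare dev/pre-release builds (e.g. "2.7.0.dev20250201") by their base version.
        installed = version.parse(installed.base_version)
    return installed >= version.parse(library_version)


# The 2.6.0-only branch could then be a range check instead of an exact string match:
compile_kwargs = {"dynamic": False}
needs_no_cudagraphs = is_torch_greater_or_equal("2.6.0") and not is_torch_greater_or_equal("2.7.0")
if needs_no_cudagraphs:
    compile_kwargs["mode"] = "max-autotune-no-cudagraphs"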