diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index d5b30d555dab..3dac6d8e4816 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -31,6 +31,7 @@
 import torch
 
 from ..utils import is_torch_flex_attn_available
+from ..utils.import_utils import _torch_version
 
 
 if is_torch_flex_attn_available():
@@ -60,8 +61,16 @@ def __init__(self):
         """
         Initialize or update the singleton instance.
         """
-        if self._is_flex_compiled is False:
-            self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor")
+        if not self._is_flex_compiled:
+            # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
+            # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
+            # see https://github.com/pytorch/pytorch/issues/146260 for training
+            if _torch_version == "2.6.0":
+                self._compiled_flex_attention = torch.compile(
+                    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
+                )
+            else:
+                self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
             self._is_flex_compiled = True
 
     def __call__(self):
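
For context, below is a minimal standalone sketch of the same version-gated compile pattern, assuming PyTorch >= 2.5 (where flex attention is available). It checks `torch.__version__` directly instead of transformers' internal `_torch_version`, and the helper name `compile_flex_attention` is illustrative, not part of the patch.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


def compile_flex_attention():
    # Hypothetical helper mirroring the patched __init__ above: PyTorch 2.6.0 has a
    # known flex attention compilation issue
    # (https://github.com/pytorch/pytorch/issues/146260), and compiling with
    # mode="max-autotune-no-cudagraphs" is the suggested workaround.
    if torch.__version__.startswith("2.6.0"):
        return torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
    return torch.compile(flex_attention, dynamic=False)
```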