src/transformers/integrations (1 file changed: +11, -2 lines)

@@ -31,6 +31,7 @@
 import torch
 
 from ..utils import is_torch_flex_attn_available
+from ..utils.import_utils import _torch_version
 
 
 if is_torch_flex_attn_available():
@@ -60,8 +61,16 @@ def __init__(self):
         """
         Initialize or update the singleton instance.
         """
-        if self._is_flex_compiled is False:
-            self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor")
+        if not self._is_flex_compiled:
+            # In PyTorch 2.6.0, there is a known issue with flex attention compilation that may
+            # cause errors during training. The suggested fix is to compile with
+            # "max-autotune-no-cudagraphs"; see https://github.com/pytorch/pytorch/issues/146260.
+            if _torch_version == "2.6.0":
+                self._compiled_flex_attention = torch.compile(
+                    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
+                )
+            else:
+                self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
             self._is_flex_compiled = True
 
     def __call__(self):
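
For reference, here is a self-contained sketch of the lazily-compiled singleton pattern this hunk modifies. It is illustrative only: the class name `CompiledFlexAttention` is hypothetical (the real wrapper class is not named in this hunk), and it assumes PyTorch >= 2.5 so that `torch.nn.attention.flex_attention` is importable; it checks `torch.__version__` directly instead of transformers' internal `_torch_version` helper.

```python
# Hypothetical, self-contained sketch of the pattern in the diff above.
# Assumes PyTorch >= 2.5 (flex attention available); this is not the actual
# transformers class, whose name is not shown in this hunk.
import torch
from torch.nn.attention.flex_attention import flex_attention


class CompiledFlexAttention:
    """Compile flex_attention once per process and reuse the compiled callable."""

    _instance = None

    def __new__(cls):
        # Classic singleton: create the instance on first use only.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._is_flex_compiled = False
            cls._instance._compiled_flex_attention = None
        return cls._instance

    def __init__(self):
        # Compile lazily, exactly once, picking the mode by torch version.
        if not self._is_flex_compiled:
            if torch.__version__.startswith("2.6.0"):
                # Avoid CUDA graphs on 2.6.0 to work around the compilation
                # issue tracked in pytorch/pytorch#146260.
                self._compiled_flex_attention = torch.compile(
                    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
                )
            else:
                self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
            self._is_flex_compiled = True

    def __call__(self):
        return self._compiled_flex_attention
```

Because the compiled function is cached on the singleton, `CompiledFlexAttention()()` can be called from every attention layer without retriggering the expensive `torch.compile` step.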