From a026a6aed22edcb824f162eb183eb65a74ab6f55 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Fri, 4 Apr 2025 11:29:24 +0100
Subject: [PATCH 1/5] adding compile kwarg for torch 2.6

---
 src/transformers/integrations/flex_attention.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index b0a054998c8a..42bbd4c23d6c 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -11,6 +11,7 @@
   year = {2024}
 }
 """
+
 # coding=utf-8
 # Copyright 2025 The HuggingFace Inc. team.
 #
@@ -31,6 +32,7 @@
 import torch
 
 from ..utils import is_torch_flex_attn_available
+from ..utils.import_utils import _torch_version
 
 
 if is_torch_flex_attn_available():
@@ -64,7 +66,12 @@ def __init__(self):
         Initialize or update the singleton instance.
         """
         if self._is_flex_compiled is False:
-            self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
+            if _torch_version == "2.6.0":
+                self._compiled_flex_attention = torch.compile(
+                    flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs"
+                )
+            else:
+                self._compiled_flex_attention = torch.compile(flex_attention, dynamic=True)
             self._is_flex_compiled = True
 
     def __call__(self):

From 0d8ba7236e9d94e10494f35091a9e89aa734a879 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Fri, 4 Apr 2025 16:20:57 +0100
Subject: [PATCH 2/5] fixing dynamic

---
 src/transformers/integrations/flex_attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index 42bbd4c23d6c..e15d8e4c4301 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -65,13 +65,13 @@ def __init__(self):
         """
         Initialize or update the singleton instance.
         """
-        if self._is_flex_compiled is False:
+        if not self._is_flex_compiled:
             if _torch_version == "2.6.0":
                 self._compiled_flex_attention = torch.compile(
-                    flex_attention, dynamic=True, mode="max-autotune-no-cudagraphs"
+                    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
                 )
             else:
-                self._compiled_flex_attention = torch.compile(flex_attention, dynamic=True)
+                self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
             self._is_flex_compiled = True
 
     def __call__(self):

From f13dc1bde80b5f82dd575cbbb61b9e6b31418678 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Mon, 7 Apr 2025 12:07:28 +0100
Subject: [PATCH 3/5] addressing comment

---
 src/transformers/integrations/flex_attention.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index e15d8e4c4301..438323f39990 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -11,7 +11,6 @@
   year = {2024}
 }
 """
-
 # coding=utf-8
 # Copyright 2025 The HuggingFace Inc. team.
 #
@@ -66,6 +65,9 @@ def __init__(self):
         Initialize or update the singleton instance.
         """
         if not self._is_flex_compiled:
+            # In PyTorch 2.6.0, there's a known issue flex attention compilation which may
+            # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
+            # see https://github.com/pytorch/pytorch/issues/146260
             if _torch_version == "2.6.0":
                 self._compiled_flex_attention = torch.compile(
                     flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"

From c636e59170fb1a6918c7c227e77e1680cf173988 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Mon, 7 Apr 2025 12:10:05 +0100
Subject: [PATCH 4/5] typo

---
 src/transformers/integrations/flex_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index 438323f39990..344e96840d8c 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -65,7 +65,7 @@ def __init__(self):
         Initialize or update the singleton instance.
         """
         if not self._is_flex_compiled:
-            # In PyTorch 2.6.0, there's a known issue flex attention compilation which may
+            # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
             # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
             # see https://github.com/pytorch/pytorch/issues/146260
             if _torch_version == "2.6.0":

From 48f8964d25086619c58463729375bdb0b9d841f4 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Mon, 7 Apr 2025 22:39:31 +0200
Subject: [PATCH 5/5] Update src/transformers/integrations/flex_attention.py

---
 src/transformers/integrations/flex_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index 344e96840d8c..879a643590d8 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -67,7 +67,7 @@ def __init__(self):
         if not self._is_flex_compiled:
             # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
             # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
-            # see https://github.com/pytorch/pytorch/issues/146260
+            # see https://github.com/pytorch/pytorch/issues/146260 for training
             if _torch_version == "2.6.0":
                 self._compiled_flex_attention = torch.compile(
                     flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
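
Note: for reference, the version-gated compile that this series converges on can be sketched as a standalone helper. This is a minimal illustration under stated assumptions, not the transformers source: the helper name, the module-level cache, and the torch.__version__ check are assumed for the sketch (transformers uses its internal _torch_version and a singleton wrapper class), while the torch.compile calls and the 2.6.0 condition mirror the final state of the hunks above.

import torch
from torch.nn.attention.flex_attention import flex_attention

_compiled_flex_attention = None  # cached so flex_attention is compiled only once


def get_compiled_flex_attention():
    """Compile flex_attention once, working around the known PyTorch 2.6.0
    compilation issue (https://github.com/pytorch/pytorch/issues/146260)."""
    global _compiled_flex_attention
    if _compiled_flex_attention is None:
        if str(torch.__version__).startswith("2.6.0"):
            # 2.6.0 workaround: autotune without cudagraphs
            _compiled_flex_attention = torch.compile(
                flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
            )
        else:
            _compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
    return _compiled_flex_attention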